# BIKE

In [1]:
# imports
from random import randbytes
from Crypto.Hash import SHAKE256, SHA3_384
from Crypto.Cipher import AES
from functools import reduce, wraps
from tqdm import tqdm

import time, os
import numpy as np

In [2]:
# randoms

# PRNG: AES-256 CTR_DRBG as specified in NIST SP 800-90A
class RNG:
    # seed is 48 bytes
    def __init__(self, seed: bytes):
        self.seed = None
        self.reseed_counter = 0
        self.Key = None
        self.V = None
        self.random_init(seed)
        
    def random_init(self, seed: bytes):
        self.seed = bytearray(seed)
        self.Key = bytearray(b'\0' * 32)
        self.V  = bytearray(b'\0' * 16)
        self.aes256_ctr_drbg_update(self.seed)
        self.reseed_counter = 1
    
    def random_bytes(self, length: int) -> bytes:
        ct = bytearray()
        cipher = AES.new(bytes(self.Key), AES.MODE_ECB)
        
        for i in range(0, length, 16):
            for j in reversed(range(16)):
                if self.V[j] == 0xff:
                    self.V[j] = 0
                else:
                    self.V[j] += 1
                    break
            ct += cipher.encrypt(bytes(self.V))
        
        self.aes256_ctr_drbg_update(None)
        self.reseed_counter += 1
        
        return bytes(ct[:length])
    
    def aes256_ctr_drbg_update(self, data: bytes):
        temp = bytearray()
        cipher = AES.new(bytes(self.Key), AES.MODE_ECB)
        
        for i in range(3):
            for j in reversed(range(16)):
                if self.V[j] == 0xff:
                    self.V[j] = 0
                else:
                    self.V[j] += 1
                    break
            temp += cipher.encrypt(bytes(self.V))

        if data != None:
            temp = [i ^^ j for i, j in zip(temp, data)] # xor over data
        
        self.Key = temp[:32]
        self.V  = temp[32:]
        
def getStream(seed: bytes):
    """Return a SHAKE256 extendable-output object that has absorbed ``seed``.

    The result is the hash object itself; callers pull output from it with
    ``.read(n)`` (see WSHAKE256_PRF).
    """
    return SHAKE256.new(data=seed)

# Algorithm 3 (WSHAKE256-PRF): called with (length, wt) = (r, w/2) for the key halves and (2r, t) for the error vector
def WSHAKE256_PRF(stream, length: int, wt: int) -> int:
    """Sample a ``length``-bit vector of Hamming weight exactly ``wt``.

    Args:
        stream: object with ``read(n) -> bytes``; 4 bytes are read per weight
            unit (e.g. the SHAKE256 object returned by getStream).
        length: number of bit positions.
        wt: number of set bits (expected ``wt <= length``).

    Returns:
        An int bitmask; bit ``p`` set means position ``p`` is in the support.
    """
    wlist = 0

    for i in reversed(range(wt)):
        # 32-bit little-endian sample scaled into [i, length) without division
        randpos = int.from_bytes(stream.read(4), 'little')
        pos = i + (((length - i) * randpos) >> 32)
        # Earlier iterations only set bits > i, so falling back to i on a
        # collision always adds a new bit and keeps the weight exact.
        if (wlist >> pos) & 1:
            pos = i
        wlist |= 1 << pos
    return wlist

In [3]:
# polynomial helpers

def int_to_poly(F: PolynomialQuotientRing, num: int):
    """Build the element of ``F`` whose coefficient of x^i is bit i of ``num``.

    ``num`` is written in binary padded to at least ``r`` digits (module-level
    ``r``), least-significant bit first, and handed to ``F`` as a coefficient list.
    """
    digits = f"{num:0{r}b}"
    coefficients = [int(d) for d in reversed(digits)]
    return F(coefficients)

def poly_to_int(F: PolynomialQuotientRing, f) -> int:
    """Inverse of int_to_poly: bit i of the result is the coefficient of x^i in ``f``.

    ``F`` is accepted for symmetry with int_to_poly but not used.
    """
    bit_string = "".join(map(str, reversed(f.list())))
    return int(bit_string, 2)

def timeit(func):
    """Decorator that prints the wall-clock duration of every call to ``func``.

    The printed line includes the call's keyword arguments.  The wrapped
    function's return value is passed through unchanged.  If it raises, the
    exception propagates and nothing is printed.
    """
    @wraps(func)
    def _timed(*args, **kwargs):
        began = time.perf_counter()
        outcome = func(*args, **kwargs)
        elapsed = time.perf_counter() - began
        print(f'Function {func.__name__} {kwargs} Took {elapsed:.8f} seconds')
        return outcome
    return _timed

In [4]:
# BIKE helpers

def functionH(m: bytes) -> "tuple[int, int]":
    """Map the message ``m`` to an error vector split into its two halves.

    Samples a ``2r``-bit vector of weight ``t`` with WSHAKE256_PRF fed by
    SHAKE256(m).  Returns ``(e0, e1)``: the low ``r`` bits and the bits above
    them (module-level ``r`` and ``t``).
    """
    e = WSHAKE256_PRF(getStream(m), r * 2, t)
    e0 = e & ((1 << r) - 1)
    e1 = e >> r
    return e0, e1

def functionL(e0: int, e1: int) -> bytes:
    """Hash the error halves to 32 bytes.

    Each half is serialized little-endian in ``ceil(r / 8)`` bytes; the
    concatenation is hashed with SHA3-384 and the digest truncated to 32 bytes.
    """
    width = ceil(r / 8)
    serialized = int(e0).to_bytes(width, 'little') + int(e1).to_bytes(width, 'little')
    digest = SHA3_384.new(data=serialized).digest()
    return digest[:32]

def functionK(m: bytes, c0: bytes, c1: bytes) -> bytes:
    """Derive a 32-byte key: SHA3-384(m || c0 || c1) truncated to 32 bytes."""
    digest = SHA3_384.new(data=b"".join((m, c0, c1))).digest()
    return digest[:32]

def compact(data):
    """Return the ascending positions ``< r`` of the set bits of ``data``.

    ``data`` is first serialized little-endian into ``ceil(r / 8)`` bytes, so a
    value that does not fit raises OverflowError.  Bits at positions ``>= r``
    that still fit in those bytes are ignored.
    """
    raw = int.to_bytes(int(data), ceil(r / 8), 'little')
    return [pos for pos in range(r) if (raw[pos >> 3] >> (pos & 7)) & 1]

In [5]:
# Decoder helpers

# transposes the row to a column
def getCol(row: list) -> list:
    """Transpose the support of a circulant row into the support of its column.

    Each position ``p`` maps to ``(r - p) % r``.  When ``row`` is ascending
    (as compact() produces), the result is ascending too.  Only the first
    ``w // 2`` entries of ``row`` are used (module-level ``r``, ``w``).

    Returns:
        numpy int array of length ``w // 2``.
    """
    half = w // 2
    col = np.zeros(half, dtype = int)
    if row[0] == 0:
        # position 0 maps to itself; the rest come out in reverse order
        col[0] = 0
        for i in range(1, half):
            col[i] = r - row[half - i]
    else:
        # integer half: ``w/2`` is not a valid range() bound
        for i in range(half):
            col[i] = r - row[half - 1 - i]
    return col

# counts the number of nonzero positions
def getHammingWeight(s: list) -> int:
    """Hamming weight: the number of nonzero entries of ``s``."""
    as_array = np.asarray(s)
    weight = np.count_nonzero(as_array)
    return weight

# number of unsatisfied parity checks
# slowest operation (this is effectively called r * w times)
def counter(h_col: list, pos: int, syn: list) -> int:
    """Count the unsatisfied parity checks touching position ``pos``.

    That is the number of nonzero ``syn[(c + pos) % r]`` over the entries ``c``
    of ``h_col`` (module-level ``r``).  ``h_col`` and ``syn`` are expected to
    be numpy arrays, since the body uses array arithmetic and fancy indexing.
    """
    rows = (h_col + pos) % r
    hits = syn[rows]
    return np.count_nonzero(hits)

In [6]:
# BGF Decoder
class BGF:
    def __init__(self, syndrome, h0_compact, h1_compact):
        self.syndrome       = np.array(syndrome, dtype = int)
        self.h0_compact     = np.array(h0_compact, dtype = int)
        self.h1_compact     = np.array(h1_compact, dtype = int)
        self.h0_compact_col = getCol(h0_compact)
        self.h1_compact_col = getCol(h1_compact)
        
        self.e     = np.zeros(r * 2, dtype = int)
        self.black = np.zeros(r * 2, dtype = int)
        self.gray  = np.zeros(r * 2, dtype = int)
    
    def flipAdjustedErrorPos(self, pos: int):
        apos = pos
        if pos != 0 and pos != r:
            if pos > r:
                apos = (2 * r - pos) + r
            else:
                apos = r - pos
        self.e[apos] ^^= 1
  
    def recompute_syndrome(self, pos: int):
        if pos < r:
            self.syndrome[(pos - self.h0_compact) % r] ^^= 1
        else:
            self.syndrome[(pos - self.h1_compact) % r] ^^= 1

    def BFIter(self, T: int):
        pos = [0] * r * 2
        
        for j in range(r):
            cnt = counter(self.h0_compact_col, j, self.syndrome)
            if cnt >= T:
                self.flipAdjustedErrorPos(j)
                pos[j] = 1
                self.black[j] = 1
            elif cnt >= (T - tau):
                self.gray[j] = 1
        
        for j in range(r):
            cnt = counter(self.h1_compact_col, j, self.syndrome)
            if cnt >= T:
                self.flipAdjustedErrorPos(j + r)
                pos[j + r] = 1
                self.black[j + r] = 1
            elif cnt >= (T - tau):
                self.gray[j + r] = 1
        
        for j in range(r * 2):
            if pos[j] == 1:
                self.recompute_syndrome(j)
                
    def BFMaskedIter(self, T: int, mask: list):
        pos = [0] * r * 2
        
        for j in range(r):
            cnt = counter(self.h0_compact_col, j, self.syndrome)
            if cnt >= T and mask[j]:
                self.flipAdjustedErrorPos(j)
                pos[j] = 1
        
        for j in range(r):
            cnt = counter(self.h1_compact_col, j, self.syndrome)
            if cnt >= T and mask[j + r]:
                self.flipAdjustedErrorPos(j + r)
                pos[j + r] = 1
        
        for j in range(r * 2):
            if pos[j] == 1:
                self.recompute_syndrome(j)

    def Decode(self):
        for i in range(NbIter):
            self.black = [0] * r * 2
            self.gray  = [0] * r * 2
            
            T = floor(threshold(getHammingWeight(self.syndrome)))
            self.BFIter(T)
            
            if i == 0:
                self.BFMaskedIter(ceil(w/4) + 1, self.black)
                self.BFMaskedIter(ceil(w/4) + 1, self.gray)
        
        return getHammingWeight(self.syndrome)

In [7]:
# BIKE class -> (keygen, encap, decap)
class BIKE:
    def __init__(self, seed):
        self.seed = seed
        self.rand = RNG(self.seed)
        
        # Polynomial Ring
        FF.<x> = GF(2)[]
        R.<x> = PolynomialRing(GF(2), 'x').quotient(x^r - 1)
        self.R = R
        
        # stored in little endian format
        self.pk = None
        self.sk = None
        self.error = None
        self.m = None
        self.ct = None
        self.ss = None

    def getSeeds(self):
        s = self.rand.random_bytes(64)
        return s[:32], s[32:]

    def crypto_kem_keypair(self):
        s1, s2 = self.getSeeds()
    
        stream = getStream(s1)
        h0 = WSHAKE256_PRF(stream, r, w // 2)
        h1 = WSHAKE256_PRF(stream, r, w // 2)
        sigma = s2
        
        self.sk = (h0, h1, sigma)
        
        h0 = int_to_poly(self.R, h0)
        h1 = int_to_poly(self.R, h1)
        h = h1 * h0^-1
        
        self.pk = poly_to_int(R, h)
        
    def crypto_kem_enc(self):
        m, _ = self.getSeeds()
        self.m = m
        
        e0, e1 = functionH(m)
        self.error = (e0, e1)
        
        c1 = functionL(e0, e1)
        c1 = bytes([i ^^ j for i,j in zip(c1, m)])

        e0 = int_to_poly(self.R, e0)
        e1 = int_to_poly(self.R, e1)
        
        h = int_to_poly(self.R, self.pk)
        c0 = e0 + e1 * h
        c0 = poly_to_int(self.R, c0)
        c0 = int.to_bytes(int(c0), ceil(r / 8), 'little')
        self.ct = (c0, c1)
        
        self.ss = functionK(m, c0, c1)
    
    # decap might not always be successful
    def crypto_kem_dec(self):
        h0c = compact(self.sk[0])
        h1c = compact(self.sk[1])
        
        self.compact = (h0c, h1c)
        
        # syndrome
        c0 = int.from_bytes(self.ct[0], 'little')
        c0 = int_to_poly(self.R, c0)
        h0 = int_to_poly(self.R, self.sk[0])
        syn = c0 * h0
        
        # transposing row to column
        syn = syn.list()
        syn = [syn[0]] + syn[1:][::-1]
        syn = list(map(int, syn))
        self.syndrome = syn.copy()
        
        bgf = BGF(syn, h0c, h1c)
        bgf.Decode()
        e = bgf.e
        
        e0, e1 = e[:r][::-1], e[r:][::-1]
        e0 = int(''.join([str(i) for i in e0]), 2)
        e1 = int(''.join([str(i) for i in e1]), 2)
        
        c1 = functionL(e0, e1)
        m = bytes([i^^j for i,j in zip(c1, self.ct[1])])
        
        assert m == self.m
        assert functionK(m, self.ct[0], self.ct[1]) == self.ss

In [8]:
# Parameter sets, one tuple per security level, in the order
#   (level, r, w, t, l, NbIter, threshold, tau, kat_file)
#
# level     -> security level (1, 3 or 5)
# r         -> block length (prime, with 2 a primitive root modulo r)
# w         -> row weight (w/2 must be odd)
# t         -> error weight
# l         -> shared secret size (256 for every level; functionK outputs 32 bytes)
#
# Decoder params
# NbIter    -> number of iterations the decoder runs
# threshold -> threshold selection rule, a function of the syndrome weight
#              (the decoder floors its result)
# tau       -> threshold gap that determines the size of the gray set
#
# kat_file  -> KAT file used for testing
security = {"level1": (1, 12323, 142, 134, 256, 5, lambda x: max(13.530 + 0.0069722 * x, 36), 3, "KAT/PQCkemKAT_BIKE_3114.rsp"),
            "level3": (3, 24659, 206, 199, 256, 5, lambda x: max(15.2588 + 0.005265 * x, 52), 3, "KAT/PQCkemKAT_BIKE_6198.rsp"),
            "level5": (5, 40973, 274, 264, 256, 5, lambda x: max(17.8785 + 0.00402312 * x, 69), 3, "KAT/PQCkemKAT_BIKE_10276.rsp")
           }

In [9]:
from tqdm import tqdm

def read_test_file(file: str) -> list:
    """Parse a KAT response file into a list of dicts, one per test vector.

    Records are runs of ``key = value`` lines separated by blank lines.
    Lines starting with ``#`` (the file header) are skipped.  Keys and values
    are kept as stripped strings, e.g. ``{"count": "0", "seed": "...", ...}``.

    Every record is kept, including the last one when the file ends with a
    single newline.  CRLF line endings are accepted.

    Raises:
        ValueError: on a non-blank, non-comment line that has no ``=``.
    """
    with open(file) as f:
        lines = f.read().splitlines()

    tests, record = [], {}
    for raw in lines:
        line = raw.strip()
        if not line:
            # a blank line closes the current record
            if record:
                tests.append(record)
                record = {}
            continue
        if line.startswith('#'):
            continue
        key, sep, value = line.partition('=')
        if not sep:
            raise ValueError(f"malformed KAT line in {file!r}: {line!r}")
        record[key.strip()] = value.strip()

    if record:
        tests.append(record)
    return tests

def unit_test(test):
    """Run keygen, encapsulation and decapsulation seeded from one KAT record.

    Returns True when the resulting shared secret equals the record's ``ss``
    field (compared as lowercase hex).
    """
    kem = BIKE(bytes.fromhex(test["seed"]))
    for stage in (kem.crypto_kem_keypair, kem.crypto_kem_enc, kem.crypto_kem_dec):
        stage()
    expected_ss = test["ss"].lower()
    return kem.ss.hex() == expected_ss

In [10]:
# Select a parameter set; this rebinds the module-level globals that the
# functions and classes above read.
(level, r, w, t, l,
 NbIter, threshold, tau, kat_file) = security["level5"]
# Only the first five KAT vectors (each one takes several seconds).
tests = read_test_file(kat_file)[:5]

In [11]:
failed = []
pbar = tqdm(tests, desc = "Testing")
for test in pbar:
    try:
        passed = unit_test(test)
    except AssertionError:
        # crypto_kem_dec asserts when decapsulation does not reproduce the
        # encapsulated values; count that vector as failed instead of
        # aborting the whole run.
        passed = False
    if not passed:
        failed.append(test["count"])
    pbar.set_postfix_str(f"Fail - {len(failed)}")

Testing: 100%|██████████| 5/5 [00:43<00:00,  8.77s/it, Fail - 0]


In [12]:
# manual testing

# tweak params as you like
# Ad-hoc parameters: these overwrite the module-level globals that every
# helper above reads.
r, w, t, l = 12323, 138, 100, 256
NbIter, tau = 5, 3
threshold = lambda x: max(13.530 + 0.0069722 * x, 36)

# need a 48 byte seed to get things started
seed = os.urandom(48)
bike = BIKE(seed)
for step in (bike.crypto_kem_keypair, bike.crypto_kem_enc, bike.crypto_kem_dec):
    step()

In [13]:
# Display the shared secret from the manual run (shown as this cell's output).
bike.ss

b'\x7f\xd3\xc7\\B?\xdb\xfaT a\xf4\x945\x18\xaci\x05\xc0"V\x90-s\xeeq0a4\xe8\xc3\xe6'

In [None]:
# original implementation

# Decoder helpers

# transposes the row to a column
def getCol(row: list) -> list:
    """Transpose the support of a circulant row into the support of its column.

    Pure-list reference version.  Each position ``p`` maps to ``(r - p) % r``;
    an ascending ``row`` yields an ascending result.  Only the first ``w // 2``
    entries of ``row`` are used (module-level ``r``, ``w``).
    """
    half = w // 2
    col = [None] * half
    if row[0] == 0:
        # position 0 maps to itself; the rest come out in reverse order
        col[0] = 0
        for i in range(1, half):
            col[i] = r - row[half - i]
    else:
        # integer half: ``w/2`` is not a valid range() bound
        for i in range(half):
            col[i] = r - row[half - 1 - i]
    return col

# counts the number of nonzero positions
def getHammingWeight(s: list) -> int:
    """Return how many entries of the list ``s`` equal 1."""
    return sum(1 for bit in s if bit == 1)

# number of unsatisfied parity checks
# slowest operation (this is effectively called r * w times)
def counter(h_col: list, pos: int, syn: list) -> int:
    """Count the unsatisfied parity checks touching position ``pos``.

    Looks at the first ``w // 2`` entries ``c`` of ``h_col`` and counts the
    truthy ``syn[(c + pos) % r]`` (module-level ``r``, ``w``).
    """
    return sum(1 for k in range(w // 2) if syn[(h_col[k] + pos) % r])

# BGF Decoder
class BGF:
    def __init__(self, syndrome, h0_compact, h1_compact):
        self.syndrome = syndrome
        self.h0_compact = h0_compact
        self.h1_compact = h1_compact
        self.h0_compact_col = getCol(h0_compact)
        self.h1_compact_col = getCol(h1_compact)
        
        self.e = [0] * r * 2
        self.black = [0] * r * 2
        self.gray  = [0] * r * 2
    
    def flipAdjustedErrorPos(self, pos: int):
        apos = pos
        if pos != 0 and pos != r:
            if pos > r:
                apos = (2 * r - pos) + r
            else:
                apos = r - pos
        self.e[apos] ^^= 1
  
    def recompute_syndrome(self, pos: int):
        if pos < r:
            for j in range(w // 2):
                val = self.h0_compact[j]
                if val <= pos:
                    self.syndrome[pos - val] ^^= 1
                else:
                    self.syndrome[r - val + pos] ^^= 1
        else:
            for j in range(w // 2):
                val = self.h1_compact[j]
                if val <= (pos - r):
                    self.syndrome[(pos - r) - val] ^^= 1
                else:
                    self.syndrome[r - val + (pos - r)] ^^= 1

    def BFIter(self, T: int):
        pos = [0] * r * 2
        
        for j in range(r):
            cnt = counter(self.h0_compact_col, j, self.syndrome)
            if cnt >= T:
                self.flipAdjustedErrorPos(j)
                pos[j] = 1
                self.black[j] = 1
            elif cnt >= (T - tau):
                self.gray[j] = 1
        
        for j in range(r):
            cnt = counter(self.h1_compact_col, j, self.syndrome)
            if cnt >= T:
                self.flipAdjustedErrorPos(j + r)
                pos[j + r] = 1
                self.black[j + r] = 1
            elif cnt >= (T - tau):
                self.gray[j + r] = 1
        
        for j in range(r * 2):
            if pos[j] == 1:
                self.recompute_syndrome(j)
                
    def BFMaskedIter(self, T: int, mask: list):
        pos = [0] * r *2
        
        for j in range(r):
            cnt = counter(self.h0_compact_col, j, self.syndrome)
            if cnt >= T and mask[j]:
                self.flipAdjustedErrorPos(j)
                pos[j] = 1
        
        for j in range(r):
            cnt = counter(self.h1_compact_col, j, self.syndrome)
            if cnt >= T and mask[j + r]:
                self.flipAdjustedErrorPos(j + r)
                pos[j + r] = 1
        
        for j in range(r * 2):
            if pos[j] == 1:
                self.recompute_syndrome(j)

    def Decode(self):
        for i in range(NbIter):
            self.black = [0] * r * 2
            self.gray  = [0] * r * 2
            
            T = floor(threshold(getHammingWeight(self.syndrome)))
            self.BFIter(T)
            
            if i == 0:
                self.BFMaskedIter(ceil(w/4) + 1, self.black)
                self.BFMaskedIter(ceil(w/4) + 1, self.gray)
        
        return getHammingWeight(self.syndrome)