In [None]:
from sage.all import *      # Tudo o que precisas para polinómios, vetores, GF(2), etc.
import hashlib
import os
from random import randrange, sample

In [None]:
# --- BIKE security level 1 parameters ---
R_BITS = 12323        # ring degree r: polynomials live in F2[x]/(x^R_BITS - 1)
N_BITS = R_BITS * 2   # error-vector length (two blocks of R_BITS bits)
DV = 71               # weight of each private polynomial h0, h1
T1 = 134              # weight of the error vector
ELL_BITS = 256        # shared-secret length in bits
ELL_SIZE = ELL_BITS // 8

# --- BGF decoder parameters ---
tau = 7      # decoder threshold parameter; NOTE(review): not referenced by the decoder as written
NbIter = 15  # intended maximum number of decoder iterations


def divide_and_ceil(x, d):
    """Return ceil(x / d) for an integer x and a positive integer d."""
    return (x + d - 1) // d


# Number of bytes needed to hold R_BITS and N_BITS bits respectively.
R_SIZE = divide_and_ceil(R_BITS, 8)
N_SIZE = divide_and_ceil(N_BITS, 8)

In [None]:
# Binary field F = GF(2) and the polynomial ring F[x]. Ring arithmetic in this
# notebook is reduced modulo x^R_BITS - 1 (the same as x^R_BITS + 1 over GF(2)).
F = GF(2)
R = PolynomialRing(F, 'x')
(x,) = R.gens()
mod_poly = x**R_BITS - 1

In [None]:
class SecretKey:
    """Private key container: the sparse pair (h0, h1) and the seed sigma."""

    def __init__(self, h0, h1, sigma, e_orig=None):
        """
        h0, h1: polynomials of F2[x]/(x^R_BITS + 1); generate_keypair gives them weight DV.
        sigma: ELL_BITS-bit vector used by decaps as the fallback seed.
        e_orig: optional original error vector, stored for testing only.
        """
        self.h0, self.h1 = h0, h1
        self.sigma = sigma
        # Kept only for experiments; the decoder does not read it.
        self.e_orig = e_orig

In [None]:
class PublicKey:
    """Public key container."""

    def __init__(self, h):
        """h: polynomial of F2[x]/(x^R_BITS + 1), equal to h1 * h0^(-1)."""
        self.h = h

In [None]:
class Ciphertext:
    """Ciphertext container (c0, c1), optionally carrying the error vector for tests."""

    def __init__(self, c0, c1, e_poly=None):
        """
        c0: polynomial of F2[x]/(x^R_BITS + 1).
        c1: ELL_BITS-bit value (bytes or a vector over F).
        e_poly: optional original error vector, for testing only.
        """
        self.c0, self.c1 = c0, c1
        # Test-only field; decaps works from c0 and c1 alone.
        self.e_poly = e_poly

In [None]:
class SharedSecret:
    """Container for the raw shared-secret bits."""

    def __init__(self, raw):
        """raw: ELL_BITS-bit value (a vector over F or bytes)."""
        self.raw = raw

In [None]:
# Set to False to silence the "[DEBUG] ..." lines printed through debug().
DEBUG_ENABLED = True


def debug(msg):
    """Print ``msg`` prefixed with "[DEBUG] " when DEBUG_ENABLED is true.

    The flag is read on every call, so changing it takes effect immediately.
    """
    if DEBUG_ENABLED:
        print(f"[DEBUG] {msg}")

In [None]:
def generate_sparse_vector(length, weight):
    """Return a random element of R with exactly ``weight`` terms of degree < ``length``.

    This is used for the private polynomials h0 and h1, so positions are drawn
    with ``random.SystemRandom`` (the OS CSPRNG) instead of the module-level
    Mersenne Twister, whose output is predictable. For the same reason the
    chosen positions are not logged.

    Raises:
        ValueError: if weight > length (raised by ``sample``).
        AssertionError: if reduction modulo mod_poly changes the weight
            (only possible when length > R_BITS).
    """
    from random import SystemRandom  # local import keeps this cell self-contained

    positions = sorted(SystemRandom().sample(range(Integer(length)), Integer(weight)))

    # Start the sum at R(0) so that weight == 0 still yields a ring element.
    poly = sum((x**Integer(pos) for pos in positions), R(0)) % mod_poly
    actual_weight = poly.hamming_weight()
    assert actual_weight == weight, f"Peso incorreto: {actual_weight} != {weight}"
    debug(f"Sparse vector (peso={weight}) gerado")
    return poly

In [None]:
def invert_polynomial(poly):
    """Return the inverse of ``poly`` in F2[x]/(x^R_BITS - 1).

    Raises:
        ValueError: if ``inverse_mod`` fails, for example because ``poly``
            is not invertible modulo ``mod_poly``. The original exception
            is chained as ``__cause__``.
    """
    try:
        inv = poly.inverse_mod(mod_poly)
    except Exception as e:
        # Chain explicitly so the Sage-level cause is kept in the traceback.
        raise ValueError(f"Falha ao inverter polinômio: {e}") from e
    debug("Inverso calculado com sucesso")
    return inv

In [None]:
def multiply_polynomials(a, b):
    """Return a * b reduced modulo mod_poly, i.e. the product in F2[x]/(x^r - 1).

    When both operands expose ``quo_rem`` (Sage polynomials), the remainder
    is taken from ``quo_rem``. Otherwise the product is reduced with ``%``.
    """
    product = a * b
    both_are_polys = hasattr(a, 'quo_rem') and hasattr(b, 'quo_rem')
    if not both_are_polys:
        return product % mod_poly
    _, remainder = product.quo_rem(mod_poly)
    return remainder

In [None]:
def split_vector(e_poly):
    """Split an error polynomial of degree < 2*R_BITS into its two halves (e0, e1).

    e0 holds the coefficients of degrees 0 .. R_BITS-1. e1 holds those of
    degrees R_BITS .. 2*R_BITS-1, shifted down by R_BITS. Coefficients of
    degree >= 2*R_BITS are dropped. Both halves are elements of R.
    """
    shift = x**R_BITS
    e0 = (e_poly % shift) % mod_poly
    e1 = ((e_poly // shift) % shift) % mod_poly

    debug(f"Split: e0 peso={e0.hamming_weight()}, e1 peso={e1.hamming_weight()}")
    return e0, e1

In [None]:
def xor_vectors(a, b):
    """Return the element-wise XOR (sum over F) of two bit sequences as a vector over F.

    The length of ``a`` sets the result length. If both inputs expose
    ``.list`` (Sage vectors), their entries are added directly. Otherwise
    every entry is first coerced into F.
    """
    indices = range(len(a))
    if not (hasattr(a, 'list') and hasattr(b, 'list')):
        return vector(F, [F(a[k]) + F(b[k]) for k in indices])
    return vector(F, [a[k] + b[k] for k in indices])

In [None]:
def bytes_to_vector(data, target_bits):
    """Unpack bytes into a vector over F, least-significant bit of each byte first.

    At most ``target_bits`` bits are kept. If ``data`` holds fewer bits, the
    result is correspondingly shorter.
    """
    bits = [(byte >> k) & 1 for byte in data for k in range(8)]
    return vector(F, bits[:max(target_bits, 0)])

In [None]:
def vector_to_bytes(vec):
    """Pack a bit vector into bytes: bit i goes to byte i // 8, bit position i % 8 (LSB first)."""
    packed = bytearray((len(vec) + 7) // 8)
    for pos, bit in enumerate(vec):
        if Integer(bit) != 1:
            continue
        byte_index, bit_index = divmod(pos, 8)
        packed[byte_index] |= 1 << bit_index
    return bytes(packed)

In [None]:
def poly_to_bytes(poly, target_bits):
    """Pack the first ``target_bits`` coefficients of ``poly`` into bytes (LSB first).

    Missing high coefficients count as zero. Coefficients beyond
    ``target_bits`` are dropped.
    """
    coeffs = poly.list()[:target_bits]
    coeffs += [0] * (target_bits - len(coeffs))
    return vector_to_bytes(vector(F, coeffs))

In [None]:
def functionH(m_vec):
    """H: map an ELL_BITS-bit message to an error polynomial of weight exactly T1.

    How it works:
    - The message is packed with vector_to_bytes.
    - Each SHAKE256(m || counter) call, with counter = 0, 1, 2, ... encoded
      as 4 bytes little-endian, yields one 32-bit little-endian word.
    - Words are turned into positions in [0, N_BITS) by rejection sampling.
      Words at or above the largest multiple of N_BITS that fits in 32 bits
      are discarded, so every position is equally likely. A bare
      ``% N_BITS`` slightly favours small positions.
    - Distinct positions are collected until there are T1 of them.

    Returns an element of R with T1 non-zero coefficients, all of degree < N_BITS.
    """
    m_bytes = vector_to_bytes(m_vec)
    # Words >= limit would make the final `% N_BITS` non-uniform.
    limit = (2**32 // N_BITS) * N_BITS
    positions = set()
    counter = 0

    while len(positions) < T1:
        shake = hashlib.shake_256()
        shake.update(m_bytes + int(counter).to_bytes(4, 'little'))
        counter += 1
        word = int.from_bytes(shake.digest(4), 'little')
        if word < limit:
            positions.add(word % N_BITS)

    # Every position is < N_BITS, so no reduction modulo x^N_BITS - 1 is needed.
    poly = sum((x**Integer(pos) for pos in sorted(positions)), R(0))

    debug(f"functionH: gerado vetor peso {len(positions)}")
    return poly

In [None]:
def functionL(e_poly):
    """L(e): SHA3-384 of the N_BITS-bit packing of ``e_poly``, truncated to ELL_BITS bits."""
    digest = hashlib.sha3_384(poly_to_bytes(e_poly, N_BITS)).digest()
    return bytes_to_vector(digest[:ELL_SIZE], ELL_BITS)

In [None]:
def functionK(m_vec, c0_poly, c1_vec):
    """K(m, c0, c1): SHA3-384 of m || c0 || c1, truncated to ELL_BITS bits.

    c0 is packed as R_BITS coefficients. m and c1 are packed with vector_to_bytes.
    """
    digest = hashlib.sha3_384(
        vector_to_bytes(m_vec)
        + poly_to_bytes(c0_poly, R_BITS)
        + vector_to_bytes(c1_vec)
    ).digest()
    return bytes_to_vector(digest[:ELL_SIZE], ELL_BITS)

In [None]:
def create_circulant_matrix(poly):
    """Return the R_BITS x R_BITS circulant matrix over F defined by ``poly``.

    Entry [i, j] is the coefficient of x^((j - i) mod R_BITS) in ``poly``,
    i.e. row i is the coefficient vector rotated right by i positions.
    Coefficients of degree >= R_BITS are ignored.
    """
    circulant = Matrix(F, R_BITS, R_BITS)  # starts as the zero matrix

    # Only non-zero coefficients need to be written: coefficient p of row i
    # lands in column (i + p) mod R_BITS. That is R_BITS * weight assignments
    # instead of the R_BITS**2 (~1.5e8 for R_BITS = 12323) of a full scan.
    support = [p for p, c in enumerate(poly.list()[:R_BITS]) if c]
    one = F(1)
    for i in range(R_BITS):
        for p in support:
            circulant[i, (i + p) % R_BITS] = one

    return circulant

In [None]:
def compute_syndrome(c0_poly, h0_poly):
    """Return the syndrome c0 * h0 mod (x^R_BITS - 1) as a length-R_BITS vector over F."""
    coeffs = multiply_polynomials(c0_poly, h0_poly).list()[:R_BITS]
    coeffs += [0] * (R_BITS - len(coeffs))
    return vector(F, coeffs)

In [None]:
def _bgf_threshold(syndrome_weight):
    """Flipping threshold of the BGF decoder for a given residual syndrome weight.

    threshold(S) = max(floor(0.0069722 * S + 13.530), (DV + 1) // 2).
    The affine coefficients are the BIKE level-1 values, matching
    R_BITS = 12323 / DV = 71 configured here. Other levels use other coefficients.
    """
    return max(int(0.0069722 * syndrome_weight + 13.530), (DV + 1) // 2)


def _bgf_counters(s, supports, r):
    """Unsatisfied-parity-check counters of every error bit, one list per block.

    Error bit j of a block touches syndrome positions (j + p) mod r for every
    exponent p in that block's support. So each set syndrome bit k adds one
    to the counter of bit (k - p) mod r.
    """
    ones = [k for k, bit in enumerate(s) if bit]
    counters = []
    for support in supports:
        ctr = [0] * r
        for k in ones:
            for p in support:
                ctr[(k - p) % r] += 1
        counters.append(ctr)
    return counters


def _bgf_counter(s, support, r, j):
    """Unsatisfied-parity-check counter of a single error bit j of one block."""
    return sum(s[(j + p) % r] for p in support)


def _bgf_flip(s, e, supports, r, block, j):
    """Flip error bit j of ``block`` and update the residual syndrome ``s`` in place."""
    e[block][j] = 1 - e[block][j]
    for p in supports[block]:
        idx = (j + p) % r
        s[idx] = 1 - s[idx]


def BGF_decoder_simplified(syndrome_vec, h0_poly, h1_poly):
    """Decode a syndrome with a Black-Gray-Flip (BGF) bit-flipping decoder.

    Looks for e0, e1 (R_BITS bits each) with e0*h0 + e1*h1 == syndrome in
    F2[x]/(x^R_BITS - 1). This is the syndrome compute_syndrome(c0, h0)
    produces for c0 = e0 + e1*h with h = h1 * h0^(-1).

    Follows the BGF structure of the BIKE specification:
    - Up to NbIter iterations of threshold flipping.
    - In the first iteration only, two extra masked passes re-examine the
      "black" bits (just flipped) and the "gray" bits (counters just below
      the threshold).
    - Counters come from the supports of h0/h1 (O(|s| * DV) work per pass),
      not from dense R_BITS x R_BITS circulant matrices.
    - The decoder tracks the residual syndrome syndrome + H*e; success means
      the residual is zero.

    Args:
        syndrome_vec: length-R_BITS vector over F (as returned by compute_syndrome).
        h0_poly, h1_poly: the private sparse polynomials.

    Returns:
        (R(0), True) if the syndrome is already zero;
        (e_poly, True) on success, with e_poly = e0 + x^R_BITS * e1
            (the layout split_vector expects);
        (None, False) if the residual syndrome is non-zero after NbIter iterations.
    """
    debug("BGF_decoder: iniciando decodificação simplificada")
    r = R_BITS

    # Supports (exponents of non-zero coefficients) of the two circulant blocks.
    supports = tuple(
        [p for p, c in enumerate(poly.list()[:r]) if c] for poly in (h0_poly, h1_poly)
    )

    # Residual syndrome as a 0/1 list; every flip updates it incrementally.
    s = [int(bit) for bit in syndrome_vec][:r]
    s += [0] * (r - len(s))

    syndrome_weight = sum(s)
    debug(f"Peso da síndrome inicial: {syndrome_weight}")
    if syndrome_weight == 0:
        debug("Síndrome inicial é zero")
        return R(0), True

    e = ([0] * r, [0] * r)
    gray_margin = 3                       # BIKE tau: width of the gray zone below the threshold
    masked_threshold = (DV + 1) // 2 + 1  # fixed threshold of the two masked passes

    for iter_count in range(NbIter):
        threshold = _bgf_threshold(sum(s))
        debug(f"Iteração {iter_count+1}/{NbIter}, threshold={threshold}")

        # BFIter: all counters come from the same residual syndrome. Bits at
        # or above the threshold are flipped together (black list); bits
        # within gray_margin below it are remembered (gray list).
        counters = _bgf_counters(s, supports, r)
        black, gray = [], []
        for block, ctr in enumerate(counters):
            for j, c in enumerate(ctr):
                if c >= threshold:
                    black.append((block, j))
                elif c >= threshold - gray_margin:
                    gray.append((block, j))

        flips_e0 = sum(1 for block, _ in black if block == 0)
        debug(f"Bits para inverter: {flips_e0} em e0, {len(black) - flips_e0} em e1")
        for block, j in black:
            _bgf_flip(s, e, supports, r, block, j)

        if iter_count == 0:
            # BFMaskedIter over the black list, then the gray list. Each pass
            # evaluates its masked bits against the same residual, then flips.
            for mask in (black, gray):
                to_flip = [
                    (block, j) for block, j in mask
                    if _bgf_counter(s, supports[block], r, j) >= masked_threshold
                ]
                for block, j in to_flip:
                    _bgf_flip(s, e, supports, r, block, j)

        residual_weight = sum(s)
        debug(f"Peso da síndrome após iteração: {residual_weight}")
        if residual_weight == 0:
            debug(f"Decodificação bem-sucedida na iteração {iter_count+1}")
            break

    if sum(s) != 0:
        debug("Falha na decodificação")
        return None, False

    # e = e0 + x^R_BITS * e1: e0 in degrees [0, r), e1 in degrees [r, 2r).
    e_poly = sum(
        (x**Integer(block * r + j) for block in (0, 1) for j in range(r) if e[block][j]),
        R(0),
    )
    debug(f"Decodificação bem-sucedida, peso do erro={e_poly.hamming_weight()}")
    return e_poly, True

In [None]:
def generate_keypair():
    """Generate a key pair.

    - h0, h1: random sparse polynomials of weight DV.
    - h = h1 * h0^(-1): the public key.
    - sigma: ELL_BITS bits from the OS CSPRNG. decaps uses it as the fallback seed.

    Returns (PublicKey(h), SecretKey(h0, h1, sigma)).
    Raises ValueError if h0 cannot be inverted (from invert_polynomial).
    """
    debug("=== Início keypair ===")

    h0 = generate_sparse_vector(R_BITS, DV)
    h1 = generate_sparse_vector(R_BITS, DV)

    # sigma is secret: draw it from os.urandom rather than Sage's random_vector,
    # which is not a cryptographic generator.
    sigma = bytes_to_vector(os.urandom(ELL_SIZE), ELL_BITS)

    h0_inv = invert_polynomial(h0)
    h = multiply_polynomials(h1, h0_inv)

    sk = SecretKey(h0, h1, sigma)
    pk = PublicKey(h)

    debug("=== Fim keypair ===")
    return pk, sk

In [None]:
def encaps(pk):
    """Encapsulate a fresh shared secret under public key ``pk``.

    Steps:
        m  <- ELL_BITS random bits from the OS CSPRNG
        e  =  H(m), split into (e0, e1)
        c0 =  e0 + e1*h,   c1 = L(e) XOR m,   K = K(m, c0, c1)

    Returns (Ciphertext(c0, c1), SharedSecret(K)). The error vector is not
    attached to the ciphertext: since c1 = L(e) XOR m, anyone holding e and
    c1 recovers m and therefore K.
    """
    debug("=== Início encaps ===")

    # m must be unpredictable: use os.urandom, not Sage's random_vector.
    m = bytes_to_vector(os.urandom(ELL_SIZE), ELL_BITS)

    e_poly = functionH(m)
    e0_poly, e1_poly = split_vector(e_poly)

    temp = multiply_polynomials(e1_poly, pk.h)
    c0 = (e0_poly + temp) % mod_poly

    L_vec = functionL(e_poly)
    c1 = xor_vectors(L_vec, m)

    ss = functionK(m, c0, c1)

    debug("=== Fim encaps ===")
    return Ciphertext(c0, c1), SharedSecret(ss)

In [None]:
def decaps(ct, sk):
    """Decapsulate ``ct`` with secret key ``sk`` and return a SharedSecret.

    Steps:
    - Compute the syndrome s = c0 * h0 and decode it to e'.
    - Derive m' = c1 XOR L(e').
    - Accept m' only if H(m') == e'.

    If the decoder fails or the check does not hold, the result is
    K(sigma, c0, c1) instead of K(m', c0, c1).
    """
    debug("=== Início decaps ===")

    # Syndrome s = c0 * h0 in F2[x]/(x^r - 1).
    syndrome = compute_syndrome(ct.c0, sk.h0)
    e_prime_poly, decoder_success = BGF_decoder_simplified(syndrome, sk.h0, sk.h1)

    if not decoder_success:
        debug("Decoder falhou, usando sigma como fallback")
        return SharedSecret(functionK(sk.sigma, ct.c0, ct.c1))

    m_prime = xor_vectors(ct.c1, functionL(e_prime_poly))
    e_recomputed_poly = functionH(m_prime)

    # Both are elements of R, so polynomial equality is the same test as
    # comparing their zero-padded coefficient lists.
    if e_prime_poly == e_recomputed_poly:
        debug("Sucesso: vetores de erro coincidem")
        seed = m_prime
    else:
        debug("Falha: usando sigma como fallback")
        seed = sk.sigma

    k = functionK(seed, ct.c0, ct.c1)

    debug("=== Fim decaps ===")
    return SharedSecret(k)

In [None]:
# End-to-end check: keygen -> encaps -> decaps must agree on the shared secret.
num_tests = 5
success = 0

for i in range(num_tests):
    print(f"\n=== Teste {i+1}/{num_tests} ===")
    try:
        pk, sk = generate_keypair()
        ct, ss_enc = encaps(pk)

        # Rebuild the ciphertext without any error vector so decaps must
        # work from (c0, c1) alone.
        ct_test = Ciphertext(ct.c0, ct.c1, None)
        ss_dec = decaps(ct_test, sk)

        same_secret = vector_to_bytes(ss_enc.raw) == vector_to_bytes(ss_dec.raw)
        if not same_secret:
            print(f"❌ Teste {i+1}: Falha")
            continue
        success += 1
        print(f"✅ Teste {i+1}: Sucesso")
    except Exception as e:
        print(f"❌ Teste {i+1}: Erro - {e}")

taxa = success / num_tests * 100
print(f"\nTaxa de sucesso: {taxa:.2f}% ({success}/{num_tests})")