In [None]:
# helpers

from __future__ import annotations

import base64
import hmac
import math
import secrets
from typing import Any, Tuple

def random_bytes(n: int) -> bytes:
    """
    Return ``n`` cryptographically secure random bytes (via ``secrets``).

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        # Must be raised, not returned: returning the exception object would
        # hand callers a non-bytes value that looks like a successful result.
        raise ValueError("n debe ser positivo")
    return secrets.token_bytes(n)


def b64u_encode(b: bytes) -> str:
    """Encode ``b`` as URL-safe Base64 and strip the trailing '=' padding."""
    encoded = base64.urlsafe_b64encode(b)
    return encoded.rstrip(b"=").decode("ascii")

def b64u_decode(s: str) -> bytes:
    """Decode unpadded URL-safe Base64, re-adding any '=' padding it needs."""
    missing = (4 - len(s) % 4) % 4
    return base64.urlsafe_b64decode(s + "=" * missing)

def consteq(a: bytes, b: bytes) -> bool:
    """Compare two byte strings without early exit, delegating to hmac.compare_digest."""
    same = hmac.compare_digest(a, b)
    return same

def ensure_bytes(x: Any, *, field: str) -> bytes:
    """
    Normalize ``x`` to bytes: str is UTF-8 encoded, bytes pass through unchanged.

    Raises:
        TypeError: for any other type; the message names ``field``.
    """
    if isinstance(x, str):
        return x.encode("utf-8")
    if not isinstance(x, bytes):
        raise TypeError(f"{field} debe ser str o bytes, no {type(x).__name__}")
    return x

## Implementación de Salsa20/8

In [None]:
def _rotl32(x: int, n: int) -> int:
    # Rotacion circular de 4 bytes (32 bits)
    x &= 0xFFFFFFFF
    return ((x << n) | (x >> (32 - n))) & 0xFFFFFFFF

def _salsa20_8_core(block64: bytes) -> bytes:
    """
    Sobre un bloque de 64 bytes, retorna 64 bytes mezcladas
    """
    if len(block64) != 64:
        raise ValueError("Se espera un bloque de 64 bytes")

    # Se generan 16 palabras de 32 bits
    w = [int.from_bytes(block64[i * 4 : (i + 1) * 4])   for i in range(16)]
    # Copia
    x = w.copy()
    for _ in range(4): # 8 rondas o 4 rondas dobles
        # Rondas de columna
        x[ 4] ^= _rotl32((x[ 0] + x[12]) & 0xFFFFFFFF, 7)
        x[ 8] ^= _rotl32((x[ 4] + x[ 0]) & 0xFFFFFFFF, 9)
        x[12] ^= _rotl32((x[ 8] + x[ 4]) & 0xFFFFFFFF,13)
        x[ 0] ^= _rotl32((x[12] + x[ 8]) & 0xFFFFFFFF,18)

        x[ 9] ^= _rotl32((x[ 5] + x[ 1]) & 0xFFFFFFFF, 7)
        x[13] ^= _rotl32((x[ 9] + x[ 5]) & 0xFFFFFFFF, 9)
        x[ 1] ^= _rotl32((x[13] + x[ 9]) & 0xFFFFFFFF,13)
        x[ 5] ^= _rotl32((x[ 1] + x[13]) & 0xFFFFFFFF,18)

        x[14] ^= _rotl32((x[10] + x[ 6]) & 0xFFFFFFFF, 7)
        x[ 2] ^= _rotl32((x[14] + x[10]) & 0xFFFFFFFF, 9)
        x[ 6] ^= _rotl32((x[ 2] + x[14]) & 0xFFFFFFFF,13)
        x[10] ^= _rotl32((x[ 6] + x[ 2]) & 0xFFFFFFFF,18)

        x[ 3] ^= _rotl32((x[15] + x[11]) & 0xFFFFFFFF, 7)
        x[ 7] ^= _rotl32((x[ 3] + x[15]) & 0xFFFFFFFF, 9)
        x[11] ^= _rotl32((x[ 7] + x[ 3]) & 0xFFFFFFFF,13)
        x[15] ^= _rotl32((x[11] + x[ 7]) & 0xFFFFFFFF,18)

        # Rondas de fila
        x[ 1] ^= _rotl32((x[ 0] + x[ 3]) & 0xFFFFFFFF, 7)
        x[ 2] ^= _rotl32((x[ 1] + x[ 0]) & 0xFFFFFFFF, 9)
        x[ 3] ^= _rotl32((x[ 2] + x[ 1]) & 0xFFFFFFFF,13)
        x[ 0] ^= _rotl32((x[ 3] + x[ 2]) & 0xFFFFFFFF,18)

        x[ 6] ^= _rotl32((x[ 5] + x[ 4]) & 0xFFFFFFFF, 7)
        x[ 7] ^= _rotl32((x[ 6] + x[ 5]) & 0xFFFFFFFF, 9)
        x[ 4] ^= _rotl32((x[ 7] + x[ 6]) & 0xFFFFFFFF,13)
        x[ 5] ^= _rotl32((x[ 4] + x[ 7]) & 0xFFFFFFFF,18)

        x[11] ^= _rotl32((x[10] + x[ 9]) & 0xFFFFFFFF, 7)
        x[ 8] ^= _rotl32((x[11] + x[10]) & 0xFFFFFFFF, 9)
        x[ 9] ^= _rotl32((x[ 8] + x[11]) & 0xFFFFFFFF,13)
        x[10] ^= _rotl32((x[ 9] + x[ 8]) & 0xFFFFFFFF,18)

        x[12] ^= _rotl32((x[15] + x[14]) & 0xFFFFFFFF, 7)
        x[13] ^= _rotl32((x[12] + x[15]) & 0xFFFFFFFF, 9)
        x[14] ^= _rotl32((x[13] + x[12]) & 0xFFFFFFFF,13)
        x[15] ^= _rotl32((x[14] + x[13]) & 0xFFFFFFFF,18)

    out = [(x[i] + w[i]) & 0xFFFFFFFF for i in range(16)]

    # Unir en un mismo numero de 64 bytes
    return b"".join(o.to_bytes(4,"little") for o in out)




## Implementación de BlockMix y ROMix

In [None]:
def _block_xor(a: bytes, b:bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))

def _blockmix_salsa8(B: bytes, r: int) -> bytes:
    """
    scrypt BlockMix built on Salsa20/8.

    ``B`` holds 2*r consecutive 64-byte blocks (128*r bytes). A running 64-byte
    state, seeded with the last block, absorbs each block by XOR and is then
    mixed with Salsa20/8. The successive states are emitted with the
    even-indexed ones first, followed by the odd-indexed ones.

    Raises:
        ValueError: if ``len(B) != 128 * r``.
    """
    if len(B) != 128 * r:
        raise ValueError("BlockMix espera longitud 128*r")
    n_blocks = 2 * r
    state = B[(n_blocks - 1) * 64 : n_blocks * 64]  # last 64-byte block
    mixed = []
    for off in range(0, n_blocks * 64, 64):
        state = _salsa20_8_core(_block_xor(state, B[off : off + 64]))
        mixed.append(state)
    # Reorder: Y[0], Y[2], ..., Y[2r-2], then Y[1], Y[3], ..., Y[2r-1]
    return b"".join(mixed[0::2] + mixed[1::2])

def _integerify(B: bytes, r: int) -> int:
    if len(B) != 128 * r:
        raise ValueError("Integerify espera longitud de 128*r bytes")
    last = B[(2 * r - 1) * 64 : (2 * r) * 64]
    return int.from_bytes(last[0:8], "little")

def _romix(B: bytes, N: int, r: int) -> bytes:
    """
    scrypt ROMix: the memory-hard mixing step over one 128*r-byte block.

    The first pass records N successive BlockMix states in a table. The second
    pass performs N data-dependent lookups into that table, XORing each looked-up
    entry into the state before mixing again.
    """
    state = B
    table = []
    for _ in range(N):
        table.append(state)
        state = _blockmix_salsa8(state, r)
    for _ in range(N):
        # The index depends on the current state, so lookups are data-dependent.
        idx = _integerify(state, r) % N
        state = _blockmix_salsa8(_block_xor(state, table[idx]), r)
    return state

## Implementación de PBKDF2-HMAC-SHA256

In [None]:
import hashlib
def _prf_hmac_sha256(key: bytes, data: bytes) -> bytes:
    """PBKDF2 pseudo-random function: HMAC-SHA256(key, data), a 32-byte digest."""
    return hmac.new(key, data, hashlib.sha256).digest()

def pbkdf2_hmac_sha256(password: bytes, salt: bytes, iterations: int, dklen: int) -> bytes:
    """
    PBKDF2 with HMAC-SHA256 as the PRF (RFC 8018, section 5.2).

    Args:
        password: secret key material.
        salt: salt bytes.
        iterations: iteration count c; must be >= 1.
        dklen: number of output bytes; must be >= 1.

    Returns:
        The first ``dklen`` bytes of T_1 || T_2 || ..., where
        T_i = U_1 xor ... xor U_c, U_1 = PRF(password, salt || INT_BE32(i)),
        and U_j = PRF(password, U_{j-1}).

    Raises:
        ValueError: if ``iterations`` < 1 or ``dklen`` < 1, or if ``dklen``
            needs more than 2**32 - 1 blocks (the block index is 4 bytes).
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if iterations < 1:
        raise ValueError("iterations debe ser >= 1")
    if dklen <= 0:
        raise ValueError('dklen debe ser >= 1')
    hlen = 32
    n_blocks = -(-dklen // hlen)  # integer ceil(dklen / hlen), no float rounding
    if n_blocks > 0xFFFFFFFF:
        raise ValueError("dklen demasiado grande")
    dk = bytearray()
    for i in range(1, n_blocks + 1):
        U = _prf_hmac_sha256(password, salt + i.to_bytes(4, "big"))
        # Accumulate U_1 xor U_2 xor ... as one integer instead of byte by byte.
        T = int.from_bytes(U, "big")
        for _ in range(iterations - 1):
            U = _prf_hmac_sha256(password, U)
            T ^= int.from_bytes(U, "big")
        dk += T.to_bytes(hlen, "big")
    return bytes(dk[:dklen])

## Scrypt KDF

In [None]:
def scrypt_kdf(password: bytes, salt: bytes, N: int, r: int, p: int, dklen: int) -> bytes:
    """
    Derive a key with scrypt (RFC 7914), intended to match hashlib.scrypt.

    Args:
        password, salt: bytes.
        N: CPU/memory cost; a power of 2, > 1 and < 2**(16*r).
        r: block size parameter (positive).
        p: parallelization parameter (positive, p * 128 * r <= (2**32 - 1) * 32).
        dklen: length of the derived key in bytes (positive).

    Raises:
        ValueError: if any parameter violates the constraints above.
    """
    if N <= 1 or (N & (N - 1)) != 0:
        raise ValueError("N debe ser > 1 y potencia de 2")
    if r <= 0 or p <= 0:
        raise ValueError("r y p deben ser positivos")
    if dklen <= 0:
        raise ValueError("dklen debe ser positivo")
    # RFC 7914 section 2: N < 2^(128 * r / 8). For N = 2**k, bit_length() - 1 == k.
    if N.bit_length() - 1 >= 16 * r:
        raise ValueError("N debe ser menor que 2**(16*r)")
    # RFC 7914 section 2: p <= ((2^32 - 1) * hLen) / MFLen, with hLen = 32 and MFLen = 128 * r.
    if p * 128 * r > (2**32 - 1) * 32:
        raise ValueError("p demasiado grande para este r")

    block_len = 128 * r
    # PBKDF2-HMAC-SHA256 with 1 iteration expands the password to p blocks of 128*r bytes.
    B = pbkdf2_hmac_sha256(password, salt, 1, p * block_len)
    # ROMix each block independently, then concatenate the results.
    Bp = b"".join(
        _romix(B[i * block_len : (i + 1) * block_len], N, r) for i in range(p)
    )
    # A final 1-iteration PBKDF2-HMAC-SHA256, salted with the mixed blocks, yields dklen bytes.
    return pbkdf2_hmac_sha256(password, Bp, 1, dklen)


## Almacenamiento

In [None]:
def encode_phc(ln :int, r: int, p: int, salt: bytes, digest: bytes) -> str:
    """
    Build "$scrypt$ln=<ln>,r=<r>,p=<p>$<salt>$<digest>".

    Salt and digest are written as unpadded URL-safe Base64 (b64u_encode).
    """
    params = f"ln={ln},r={r},p={p}"
    return "$".join(("", "scrypt", params, b64u_encode(salt), b64u_encode(digest)))
def decode_phc(phc: str) -> Tuple[int, int, int, bytes, bytes]:
    """
    Parse a scrypt PHC string of the form produced by encode_phc.

    Expected shape: "$scrypt$ln=<int>,r=<int>,p=<int>$<salt>$<digest>", where
    salt and digest are unpadded URL-safe Base64.

    Returns:
        (ln, r, p, salt, digest)

    Raises:
        ValueError: if the string is not a scrypt PHC string, is malformed, or
            carries parameters scrypt cannot use (ln < 1, r < 1, p < 1, or an
            empty digest).
    """
    if not phc.startswith("$scrypt$"):
        raise ValueError("No es un PHC-string scrypt")
    try:
        # Require exactly five '$'-separated fields. '$' cannot legitimately
        # appear in the parameters or in URL-safe Base64, so extra fields mean
        # a malformed string rather than part of the hash.
        _, _, params, salt_b64, hash_b64 = phc.split("$")
        parts = dict(kv.split("=", 1) for kv in params.split(","))
        ln = int(parts["ln"]); r = int(parts["r"]); p = int(parts["p"])
        # Reject values that would only fail later with unrelated errors
        # (e.g. `1 << ln` with a negative ln, or an N <= 1).
        if ln < 1 or r < 1 or p < 1:
            raise ValueError("parametros fuera de rango")
        salt = b64u_decode(salt_b64); digest = b64u_decode(hash_b64)
        if not digest:
            raise ValueError("digest vacio")
        return ln, r, p, salt, digest
    except Exception as e:
        raise ValueError("PHC-string invalido") from e

def scrypt_hash(password: bytes | str, *, N: int = 2**14, r: int = 8, p: int = 1, dklen: int = 32, salt_len: int = 16) -> str:
    """
    Hash ``password`` with scrypt and a fresh random salt; return a PHC string.

    Args:
        password: str (UTF-8 encoded) or bytes.
        N: CPU/memory cost, a power of 2 greater than 1 (stored as ln = log2(N)).
        r, p: scrypt block size and parallelization parameters.
        dklen: digest length in bytes.
        salt_len: number of random salt bytes; must be positive.

    Raises:
        ValueError: for an invalid N or salt_len (or parameters rejected by scrypt_kdf).
        TypeError: if ``password`` is neither str nor bytes.
    """
    # Cheap parameter checks first.
    if N <= 1 or (N & (N - 1)) != 0:
        raise ValueError("N debe ser potencia de 2")
    if salt_len <= 0:
        # token_bytes(0) would silently yield an empty salt.
        raise ValueError("salt_len debe ser positivo")
    pwd = ensure_bytes(password, field="password")
    salt = secrets.token_bytes(salt_len)
    digest = scrypt_kdf(pwd, salt, N, r, p, dklen)
    ln = N.bit_length() - 1  # log2(N), since N is a power of two
    return encode_phc(ln, r, p, salt, digest)

def scrypt_verify(password: bytes | str, phc):
    """
    Check whether ``password`` matches a scrypt PHC string.

    Re-derives a key using the salt, cost parameters and digest length stored in
    ``phc``. It then compares that key with the stored digest via consteq.
    Returns True on a match, False otherwise.

    NOTE(review): N comes from the stored string (N = 1 << ln), so the cost of
    this call is controlled by whoever produced ``phc``.
    """
    secret = ensure_bytes(password, field = "password")
    log_n, r, p, salt, stored = decode_phc(phc)
    candidate = scrypt_kdf(secret, salt, 1 << log_n, r, p, len(stored))
    return consteq(candidate, stored)

def needs_rehash(phc: str, *, target_N: int, target_r: int, target_p: int, target_dklen: int) -> bool:
    """
    Return True when a stored hash should be recomputed with the target parameters.

    That is the case when its N is below ``target_N``, or when its r, p or digest
    length differs from the target values.

    Raises:
        ValueError: if ``phc`` cannot be parsed by decode_phc.
    """
    ln, r, p, _salt, digest = decode_phc(phc)
    N = 1 << ln
    return (N < target_N) or (r != target_r) or (p != target_p) or (len(digest) != target_dklen)



## Verificación

In [None]:
# Round trip with a small N (2**12) so the pure-Python KDF finishes quickly.
phc = scrypt_hash(
    "MiPasswordSegura!",
    N=1 << 12, r=8, p=1, dklen=32, salt_len=16,
)
print(phc)  # -> $scrypt$ln=12,r=8,p=1$...$...

# The correct password must verify; any other must not.
assert scrypt_verify("MiPasswordSegura!", phc) is True
assert scrypt_verify("otra-cosa", phc) is False


In [None]:
import hashlib

# Cross-check the pure-Python scrypt against hashlib.scrypt on small parameters.
P = b"password"
S = b"NaCl"
dklen = 64

for N in (16, 32, 64):
    for r in (1, 2):
        for p in (1, 2):
            mine = scrypt_kdf(P, S, N, r, p, dklen)
            ref = hashlib.scrypt(P, salt=S, n=N, r=r, p=p, dklen=dklen)
            print(N, r, p, mine == ref)
            assert mine == ref, (N, r, p)

print("OK")


16 1 1 False


AssertionError: (16, 1, 1)