# Pregunta 2: Función de Hash MD5

## Utils

In [None]:
from math import sin, floor

def bit_not(n, numbits=32):
    """Return the bitwise NOT of `n` within a `numbits`-bit word.

    Arguments:
        n: int - value to complement.
        numbits: int - word width in bits (default 32).
    Returns:
        int - the complement, always in [0, 2**numbits). For example,
        bit_not(0) == 0xFFFFFFFF, so bit_not(0) doubles as a 32-bit mask.
    """
    # Masking ~n (instead of computing mask - n) keeps the result inside the
    # word even when n has bits above `numbits` or is negative.
    return ~n & ((1 << numbits) - 1)

def left_rotate(n, c, numbits=32):
    """Rotate the low `numbits` bits of `n` left by `c` positions.

    Bits of `n` above `numbits` are discarded before rotating. `c` is reduced
    modulo `numbits`, so any integer rotation amount is accepted (a negative
    `c` therefore rotates right).

    Arguments:
        n: int - value to rotate.
        c: int - rotation amount.
        numbits: int - word width in bits (default 32).
    Returns:
        int - rotated value in [0, 2**numbits).
    """
    mask = (1 << numbits) - 1
    n &= mask
    # Without this, c > numbits would make (numbits - c) a negative shift
    # count and raise ValueError.
    c %= numbits
    return ((n << c) | (n >> (numbits - c))) & mask

def gen_shift_table():
    """Return the 64 per-step left-rotation amounts used by MD5.

    There are four rounds of 16 steps. Each round cycles through its own
    4-element pattern four times.
    """
    round_patterns = (
        (7, 12, 17, 22),
        (5, 9, 14, 20),
        (4, 11, 16, 23),
        (6, 10, 15, 21),
    )
    table = []
    for pattern in round_patterns:
        table.extend(pattern * 4)
    return table

def gen_constants_table():
    """Return the 64 MD5 additive constants K[0..63].

    Entry k is floor(2**32 * |sin(k + 1)|), with the sine argument in radians.
    """
    scale = 1 << 32
    return [floor(scale * abs(sin(step))) for step in range(1, 65)]


## `custom_md5` implementation

In [None]:
def custom_md5(m: "str | bytes", h0: int = 137269462086865085541390238039692956790) -> str:
    """
    Compute an MD5 digest (RFC 1321) of a message with a configurable H_0.

    Arguments:
        m:  str - message; a str is UTF-8 encoded before hashing. A bytes-like
            object (bytes, bytearray, memoryview) is hashed as-is.
        h0: int - initial constant H_0 (< 2**128). It is split into the four
            32-bit buffer words A, B, C, D, most significant word first.
    Returns:
        str - MD5 hash of m in hexadecimal format (digest), as a 32-character
        lowercase hex string.
    Raises:
        TypeError - if m is neither a str nor a bytes-like object.
    """
    # Constants (in bytes)
    DIGEST_SIZE = 128 // 8
    BLOCK_SIZE = 512 // 8
    LEN_LENGTH = 64 // 8
    MASK32 = bit_not(0)  # 0xFFFFFFFF

    s = gen_shift_table()
    K = gen_constants_table()

    # Starting buffer values: 32-bit words of h0, most significant first.
    A0, B0, C0, D0 = [(h0 >> 32 * (3 - i)) & MASK32 for i in range(4)]

    # Padding: one 0x80 byte, then zeros until length == 56 (mod 64), then
    # the original length in bits as a 64-bit little-endian integer.
    if isinstance(m, (bytes, bytearray, memoryview)):
        data = bytearray(m)
    else:
        data = bytearray(m, encoding='utf-8')
    original_length = len(data) * 8  # in bits
    data += b'\x80'
    # The modulo makes the zero count wrap into the next block when the 0x80
    # byte already left us past offset 56 of the current block.
    data += b'\x00' * ((BLOCK_SIZE - LEN_LENGTH - len(data)) % BLOCK_SIZE)
    # Only the low 64 bits of the bit length are appended (RFC 1321 sec. 3.2).
    length_mask = (1 << (8 * LEN_LENGTH)) - 1
    data += (original_length & length_mask).to_bytes(LEN_LENGTH, 'little')

    for offset in range(0, len(data), BLOCK_SIZE):
        block = data[offset:offset + BLOCK_SIZE]
        # Decode the block's 16 little-endian 32-bit words once, not per step.
        M = [int.from_bytes(block[k:k + 4], 'little') for k in range(0, BLOCK_SIZE, 4)]
        A, B, C, D = A0, B0, C0, D0
        for j in range(64):
            if j < 16:
                F = (B & C) | (bit_not(B) & D)
                g = j
            elif j < 32:
                F = (D & B) | (bit_not(D) & C)
                g = (5 * j + 1) % 16
            elif j < 48:
                F = B ^ C ^ D
                g = (3 * j + 5) % 16
            else:
                F = C ^ (B | bit_not(D))
                g = (7 * j) % 16
            # F may exceed 32 bits here; left_rotate masks its input.
            F = F + A + K[j] + M[g]
            A = D
            D = C
            C = B
            B = (B + left_rotate(F, s[j])) & MASK32
        A0 = (A0 + A) & MASK32
        B0 = (B0 + B) & MASK32
        C0 = (C0 + C) & MASK32
        D0 = (D0 + D) & MASK32

    # Digest bytes are A0, B0, C0, D0, each little-endian.
    digest = (D0 << 96) + (C0 << 64) + (B0 << 32) + A0
    return digest.to_bytes(DIGEST_SIZE, 'little').hex()


## Testing

In [None]:
if __name__ == '__main__':
    # Self-check: compare custom_md5 (default h0) against hashlib's MD5.
    from hashlib import md5
    # 44 ASCII chars * 100 = 4400 bytes, i.e. a multi-block message.
    # NOTE(review): 4400 % 64 == 48, so this exercises only one padding case;
    # lengths with len % 64 in 56..63 (and the empty message) are not covered.
    example_string = "The quick brown fox jumps over the lazy dog."*100
    md5_output = md5(example_string.encode()).hexdigest()
    custom_md5_output = custom_md5(example_string)
    print("REAL MD5:", md5_output)
    print("CUSTOM MD5:", custom_md5_output)
    assert(md5_output == custom_md5_output)
