# MD4
**Implementation of the RFC-1320 MD4 Algorithm**

*... in Python*

This implementation of MD4 is for demonstration purposes. It shows the MD4 algorithm in action and prints all intermediate steps.

A graphical representation of a single MD4 operation is shown below. MD4 performs 48 such operations per block, in three rounds of 16.

![MD4](https://upload.wikimedia.org/wikipedia/commons/thumb/1/1a/MD4.svg/300px-MD4.svg.png)
Image from Wikipedia

### Define and Select Test Cases

In [1]:
test_case=[["","31d6cfe0d16ae931b73c59d7e0c089c0"],\
           ["a","bde52cb31de33e46245e05fbdbd6fb24"],\
           ["abc","a448017aaf21d8525fc10ae87aa6729d"],\
           ["message digest","d9130a8164549fe818874806e1c7014b"],\
           ["abcdefghijklmnopqrstuvwxyz","d79e1c308aa5bbcdeea8ed63df412da9"],\
           ["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789","043f8582f241db351ce627e153e7f0e4"],\
           ["12345678901234567890123456789012345678901234567890123456789012345678901234567890","e33b4ddc9c38f2199c3e7b164fcc0536"]]
use_test_case = 1
##
message = test_case[use_test_case][0]
ref_hash = test_case[use_test_case][1]

### Step 1 - Append Padding Bits

The message to be hashed is padded until its length is 8 bytes {64 bits} short of a multiple of 64 bytes {512 bits}. Padding is always performed, even if the message length already meets this condition. In that case, a full 64 bytes {512 bits} of padding is added. The padding bit string is a single `1` bit followed by `0` bits (`100...000`). For a byte-oriented message, this means the byte `0x80` followed by `0x00` bytes.

The padded message length is therefore 56 bytes {448 bits}, 120 bytes {960 bits}, 184 bytes {1472 bits}, 248 bytes {1984 bits}, and so on.
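Below is a minimal sketch of Step 1 on its own. It assumes the message is already a `bytes` object, and `pad_to_448` is a hypothetical helper name, not part of the notebook. The zero count is computed as `(56 - len(padded)) % 64` after the `0x80` byte is appended. This handles messages of 56 to 63 bytes, which must spill into a second 64-byte block.

```python
def pad_to_448(message: bytes) -> bytes:
    """Step 1 sketch: append a '1' bit (the byte 0x80), then zero bytes
    until the length is 56 bytes modulo 64 (448 bits modulo 512)."""
    padded = message + b'\x80'
    return padded + b'\x00' * ((56 - len(padded)) % 64)

# Padded lengths for a few message lengths around the block boundary
for n in (0, 1, 55, 56, 62, 64, 120):
    print(n, "bytes ->", len(pad_to_448(b'x' * n)), "bytes")
```

A 55-byte message needs only the single `0x80` byte. A 56-byte message already sits at 448 bits, so it receives a full 64 bytes of padding and grows to 120 bytes.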

In [2]:
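# Length in bytes of the UTF-8 encoded message (equals len(message) for ASCII input)
message_len = len(message.encode('utf-8'))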
message_len_bits = message_len * 8
print("Message Length : " + str(message_len) + " bytes {" + str(message_len_bits) + " bits}")

Message Length : 1 bytes {8 bits}


In [3]:
# Encode string to bytes
message_b = message.encode('utf-8')

In [4]:
# Calculate padding length: 1 to 64 bytes, so that the padded length is 56 (mod 64)
padding_len = (56 - message_len % 64) % 64
padding_len = 64 if (padding_len == 0) else padding_len
print("Padding Length : " + str(padding_len) + " bytes {" + str(padding_len * 8) + " bits}")

Padding Length : 55 bytes {440 bits}


In [5]:
# Display the padded message, its length, and its length modulo 64 bytes
message_mod448 = message_b + b'\x80' + b'\x00' * (padding_len-1)
print("Padded Message :\n"+str(message_mod448))
print("\nlength(paddedMessage)      : "+str(len(message_mod448))+" bytes {"+str(len(message_mod448)*8)+" bits}")
print("length(paddedMessage) % 64 : "+str(len(message_mod448)%64)+" bytes {"+str((len(message_mod448)%64) * 8)+" bits}")

Padded Message :
b'a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'

length(paddedMessage)      : 56 bytes {448 bits}
length(paddedMessage) % 64 : 56 bytes {448 bits}


### Step 2 - Append Length

The bit length of the original (unpadded) message is appended to the padded message, which is 64 bits short of a multiple of 512 bits. The bit length is appended as an 8 byte {64 bits} little-endian integer. This brings the total length to an exact multiple of 64 bytes {512 bits}.

For example, take a message of length 14 bytes (_try test case 3_, `"message digest"`). Its bit length is 112 bits (`0x70`). The appended 64-bit little-endian length is the byte sequence `70 00 00 00 00 00 00 00` (as hex), or `b'p\x00\x00\x00\x00\x00\x00\x00'` as a byte string, because `0x70` is the ASCII code of `p`. If the message length is $\geq 2^{64}$ bits, only the low-order 64 bits of the length are appended.
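Below is a sketch of Step 2. It assumes the padded message from Step 1 and the original message length in bytes. `append_length` is a hypothetical helper. The `% 2**64` keeps only the low-order 64 bits of the bit length, as RFC 1320 requires.

```python
def append_length(padded: bytes, message_len_bytes: int) -> bytes:
    """Step 2 sketch: append the original bit length as a 64-bit
    little-endian integer (only the low-order 64 bits are kept)."""
    bit_len = (message_len_bytes * 8) % 2**64
    return padded + bit_len.to_bytes(8, byteorder='little')

# Worked example for test case 3, "message digest" (14 bytes = 112 bits)
msg = b"message digest"
padded = msg + b'\x80' + b'\x00' * (56 - len(msg) - 1)
block = append_length(padded, len(msg))
print(len(block), block[-8:])  # 64 b'p\x00\x00\x00\x00\x00\x00\x00'
```

A length of `2**61` bytes is exactly `2**64` bits, so the appended length for such a message wraps around to eight zero bytes.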

In [6]:
# Append Length
processed_message=message_mod448+(message_len_bits%2**64).to_bytes(8,byteorder='little')
print("LSB64(len(unPaddedMessage)) : "+str((message_len_bits%2**64).to_bytes(8,byteorder='little')))
print("length( paddedMessage | LSB64(len(unPaddedMessage)) ) : "+str(len(processed_message))+" bytes {"+str(len(processed_message)*8)+" bits}")
print("\nPadded Message | LSB64(len(unPaddedMessage)) :\n"+str(processed_message))

LSB64(len(unPaddedMessage)) : b'\x08\x00\x00\x00\x00\x00\x00\x00'
length( paddedMessage | LSB64(len(unPaddedMessage)) ) : 64 bytes {512 bits}

Padded Message | LSB64(len(unPaddedMessage)) :
b'a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00'


### Step 3 - Initialize MD Buffer

In [7]:
A = 0x67452301
B = 0xEFCDAB89
C = 0x98BADCFE
D = 0x10325476
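RFC 1320 lists the initial buffer words low-order byte first (`word A: 01 23 45 67`, and so on). The short check below repeats the constants from the cell above so that it runs on its own. It shows that encoding each word as 4 little-endian bytes reproduces the RFC listing. `bytes.hex(' ')` with a separator assumes Python 3.8 or later.

```python
# Initial MD buffer (same values as the cell above)
A, B, C, D = 0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476

# Print each word low-order byte first, as RFC 1320 lists them
for name, word in zip("ABCD", (A, B, C, D)):
    print("word " + name + ": " + word.to_bytes(4, byteorder='little').hex(' '))
```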

### Step 4 - Process Message in 16-Word Blocks

In [8]:
# Auxiliary functions: each takes three 32-bit words and returns one 32-bit word.

def F(X, Y, Z):
    return ((X&Y) | ((~X) & Z))

def G(X, Y, Z):
    return ((X&Y) | (X&Z) |(Y&Z))

def H(X, Y, Z):
    return (X^Y^Z)
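RFC 1320 describes the three auxiliary functions bit by bit:

- `F` is a conditional (if X then Y else Z).
- `G` is a majority function.
- `H` is parity (XOR).

The sketch below repeats the definitions so it runs on its own. It checks these descriptions on all eight single-bit inputs. It then gives one 32-bit example of `F`, which takes bits of `Y` where `X` is 1 and bits of `Z` where `X` is 0.

```python
from itertools import product

def F(X, Y, Z):
    return (X & Y) | ((~X) & Z)

def G(X, Y, Z):
    return (X & Y) | (X & Z) | (Y & Z)

def H(X, Y, Z):
    return X ^ Y ^ Z

# Check each function on every combination of single input bits
for x, y, z in product((0, 1), repeat=3):
    assert F(x, y, z) == (y if x else z)                # conditional
    assert G(x, y, z) == (1 if x + y + z >= 2 else 0)   # majority
    assert H(x, y, z) == (x + y + z) % 2                # parity

# 32-bit example: high half taken from Y, low half from Z
print(hex(F(0xFFFF0000, 0x12345678, 0x9ABCDEF0)))  # 0x1234def0
```

Python's `~` produces a negative integer. However, `& Z` with a 32-bit `Z` brings the result back into the 32-bit range, so `F` needs no extra mask.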

In [9]:
# Rotate Left
def rotl(x,s):
    return ( (x<<s) | x>>(32-s))
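The `rotl` above assumes `x` fits in 32 bits and does not mask its result. Any bits shifted past position 31 therefore stay in the Python integer; the round functions clear them afterwards with `& 0xFFFFFFFF`. The small check below uses a hypothetical masked variant named `rotl32`.

```python
def rotl(x, s):
    # Copy of the notebook's unmasked rotate-left
    return ((x << s) | x >> (32 - s))

def rotl32(x, s):
    # Masked variant: the result always fits in 32 bits
    return ((x << s) | (x >> (32 - s))) & 0xFFFFFFFF

print(hex(rotl(0x80000001, 1)))    # 0x100000003, bit 32 is still set
print(hex(rotl32(0x80000001, 1)))  # 0x3
```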

In [10]:
# Shift Table
Round_shifts=[3,7,11,19]*4+[3,5,9,13]*4+[3,9,11,15]*4
print(Round_shifts)

# Message word index k used at each step of rounds 1, 2 and 3
R1_k = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
R2_k = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
R3_k = [0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15]

[3, 7, 11, 19, 3, 7, 11, 19, 3, 7, 11, 19, 3, 7, 11, 19, 3, 5, 9, 13, 3, 5, 9, 13, 3, 5, 9, 13, 3, 5, 9, 13, 3, 9, 11, 15, 3, 9, 11, 15, 3, 9, 11, 15, 3, 9, 11, 15]
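The word-index tables have a simple structure, which may help when reading the step listing:

- Round 1 takes the words in order.
- Round 2 reads a 4x4 grid of word indices column by column.
- Round 3 uses the 4-bit bit-reversal of the step number.

The sketch below rebuilds the tables from these rules. The `_gen` names are hypothetical; the notebook itself uses the literal lists above.

```python
# Round 1: words 0..15 in order
R1_gen = list(range(16))

# Round 2: words laid out row-major in a 4x4 grid, read column by column
R2_gen = [col + 4 * row for col in range(4) for row in range(4)]

# Round 3: reverse the 4 bits of each step index (e.g. 1 = 0001 -> 1000 = 8)
R3_gen = [int(format(i, '04b')[::-1], 2) for i in range(16)]

print(R2_gen)
print(R3_gen)
```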


In [11]:

def bytereverse(num128):
    # Reverse the byte order of a 128-bit (16-byte) integer.
    rev_byte = 0
    for i in range(0, 16):
        # Shift the result left by one byte and append the current low-order byte
        rev_byte = rev_byte << 8
        low_order_byte = num128 & 0xFF
        rev_byte = rev_byte | low_order_byte
        num128 = num128 >> 8
    return rev_byte
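The same byte reversal can also be written with `int.to_bytes` and `int.from_bytes`. Alternatively, the digest string can be built directly from the four words, each written as 4 little-endian bytes. The helper names below are hypothetical. Formatting with `'032x'`, or building the string with `bytes.hex()`, keeps leading zero bytes that `hex()` on an integer would drop.

```python
def bytereverse_loop(num128):
    # Byte-at-a-time reversal of 16 bytes, as in the cell above
    rev = 0
    for _ in range(16):
        rev = (rev << 8) | (num128 & 0xFF)
        num128 >>= 8
    return rev

def bytereverse_bytes(num128):
    # Encode as 16 big-endian bytes, then read them back as little-endian
    return int.from_bytes(num128.to_bytes(16, byteorder='big'), byteorder='little')

def digest_hex(a, b, c, d):
    # Output order: A, B, C, D, each as 4 little-endian bytes
    return b''.join(w.to_bytes(4, byteorder='little') for w in (a, b, c, d)).hex()

# Example words whose digest starts with a zero byte (A's low byte is 0x00)
a, b, c, d = 0x12345600, 0x89ABCDEF, 0x01234567, 0xFEDCBA98
packed = d << 96 | c << 64 | b << 32 | a
print(hex(bytereverse_bytes(packed)))           # hex() drops the leading '00'
print(format(bytereverse_bytes(packed), '032x'))
print(digest_hex(a, b, c, d))
```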

In [12]:
## Round Functions
# Each call performs one step: a = (a + fn(b,c,d) + X[k] + constant) <<< s
def round1(a, b, c, d, X, k, s):
    # X[k] is the k-th 32-bit little-endian word of the 64-byte block
    Xk = int.from_bytes(X[4*k:4*k+4],byteorder='little')
    FN = F(b,c,d)
    a = (a + FN + Xk) & 0xFFFFFFFF
    a = rotl(a, s) & 0xFFFFFFFF
    round_output(1,k,s,FN)
    return a 

def round2(a, b, c, d, X, k, s):
    Xk = int.from_bytes(X[4*k:4*k+4],byteorder='little')
    FN = G(b,c,d)
    # 0x5A827999 is the round 2 constant (RFC 1320: it represents the square root of 2)
    a = (a + FN + Xk + 0x5A827999) & 0xFFFFFFFF
    a = rotl(a, s) & 0xFFFFFFFF
    round_output(2,k,s,FN)
    return a 

def round3(a, b, c, d, X, k, s):
    Xk = int.from_bytes(X[4*k:4*k+4],byteorder='little')
    FN = H(b,c,d)
    # 0x6ED9EBA1 is the round 3 constant (RFC 1320: it represents the square root of 3)
    a = (a + FN + Xk + 0x6ED9EBA1) & 0xFFFFFFFF
    a = rotl(a, s) & 0xFFFFFFFF
    round_output(3,k,s,FN)
    return a

def round_output(R,k,s,FN):
    # Print the round number, message word index k, shift s and auxiliary function value
    print("     R"+str(R)+" | K = "+"{:2d}".format(k)+" | s = "+"{:2d}".format(s)+" | {:9x}".format(FN))
    return None
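To make the step formula concrete, the sketch below recomputes the very first Round 1 step for test case 1 (`"a"`) by hand. It rebuilds the 64-byte block from Steps 1 and 2. It uses hypothetical lowercase names so the notebook's `A`, `B`, `C`, `D` are not overwritten.

Its `F(B, C, D)` value matches the first `R1` line printed by the processing loop (`98badcfe`). Feeding the new `A` into the next step's `F(A, B, C)` gives `98bedffe`, which matches the second line. Note that `A + F(B, C, D)` is exactly `0xFFFFFFFF` here. Adding `X[0]` therefore overflows, so the mask matters.

```python
MASK = 0xFFFFFFFF

def F(X, Y, Z):
    return (X & Y) | ((~X) & Z)

def rotl32(x, s):
    return ((x << s) | (x >> (32 - s))) & MASK

# 64-byte block for "a": message, 0x80, 54 zero bytes, 64-bit length (8 bits)
block = b'a\x80' + b'\x00' * 54 + (8).to_bytes(8, byteorder='little')

a0, b0, c0, d0 = 0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476
x0 = int.from_bytes(block[0:4], byteorder='little')  # word X[0]
fn = F(b0, c0, d0)
new_a = rotl32((a0 + fn + x0) & MASK, 3)              # first step, k = 0, s = 3
print(hex(x0), hex(fn), hex(new_a))                   # 0x8061 0x98badcfe 0x40300
```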

In [13]:
# Loop through the 512-bit (64-byte) blocks of the padded message.
for i in range(0,len(processed_message),64):
    print("PROCESSING bytes "+str(i)+"..."+str(i+64))
    X  = processed_message[i:i+64]
    print("\nMessage chunk being processed :\n"+str(X)+" \n")
    AA = A
    BB = B
    CC = C
    DD = D
    
    #round 1
    print("*** ROUND 1 ***")
    A = round1(A, B, C, D, X, R1_k[0], Round_shifts[0])
    D = round1(D, A, B, C, X, R1_k[1], Round_shifts[1])
    C = round1(C, D, A, B, X, R1_k[2], Round_shifts[2])
    B = round1(B, C, D, A, X, R1_k[3], Round_shifts[3])
    
    A = round1(A, B, C, D, X, R1_k[4], Round_shifts[4])
    D = round1(D, A, B, C, X, R1_k[5], Round_shifts[5])
    C = round1(C, D, A, B, X, R1_k[6], Round_shifts[6])
    B = round1(B, C, D, A, X, R1_k[7], Round_shifts[7])
    
    A = round1(A, B, C, D, X, R1_k[8], Round_shifts[8])
    D = round1(D, A, B, C, X, R1_k[9], Round_shifts[9])
    C = round1(C, D, A, B, X, R1_k[10], Round_shifts[10])
    B = round1(B, C, D, A, X, R1_k[11], Round_shifts[11])
    
    A = round1(A, B, C, D, X, R1_k[12], Round_shifts[12])
    D = round1(D, A, B, C, X, R1_k[13], Round_shifts[13])
    C = round1(C, D, A, B, X, R1_k[14], Round_shifts[14])
    B = round1(B, C, D, A, X, R1_k[15], Round_shifts[15])
    
    print("\n*** ROUND 2 ***")
    A = round2(A, B, C, D, X, R2_k[0], Round_shifts[16])
    D = round2(D, A, B, C, X, R2_k[1], Round_shifts[17])
    C = round2(C, D, A, B, X, R2_k[2], Round_shifts[18])
    B = round2(B, C, D, A, X, R2_k[3], Round_shifts[19])
    
    A = round2(A, B, C, D, X, R2_k[4], Round_shifts[20])
    D = round2(D, A, B, C, X, R2_k[5], Round_shifts[21])
    C = round2(C, D, A, B, X, R2_k[6], Round_shifts[22])
    B = round2(B, C, D, A, X, R2_k[7], Round_shifts[23])
    
    A = round2(A, B, C, D, X, R2_k[8], Round_shifts[24])
    D = round2(D, A, B, C, X, R2_k[9], Round_shifts[25])
    C = round2(C, D, A, B, X, R2_k[10], Round_shifts[26])
    B = round2(B, C, D, A, X, R2_k[11], Round_shifts[27])
    
    A = round2(A, B, C, D, X, R2_k[12], Round_shifts[28])
    D = round2(D, A, B, C, X, R2_k[13], Round_shifts[29])
    C = round2(C, D, A, B, X, R2_k[14], Round_shifts[30])
    B = round2(B, C, D, A, X, R2_k[15], Round_shifts[31])
    
    print("\n*** ROUND 3 ***")
    A = round3(A, B, C, D, X, R3_k[0], Round_shifts[32])
    D = round3(D, A, B, C, X, R3_k[1], Round_shifts[33])
    C = round3(C, D, A, B, X, R3_k[2], Round_shifts[34])
    B = round3(B, C, D, A, X, R3_k[3], Round_shifts[35])
    
    A = round3(A, B, C, D, X, R3_k[4], Round_shifts[36])
    D = round3(D, A, B, C, X, R3_k[5], Round_shifts[37])
    C = round3(C, D, A, B, X, R3_k[6], Round_shifts[38])
    B = round3(B, C, D, A, X, R3_k[7], Round_shifts[39])
    
    A = round3(A, B, C, D, X, R3_k[8], Round_shifts[40])
    D = round3(D, A, B, C, X, R3_k[9], Round_shifts[41])
    C = round3(C, D, A, B, X, R3_k[10], Round_shifts[42])
    B = round3(B, C, D, A, X, R3_k[11], Round_shifts[43])
    
    A = round3(A, B, C, D, X, R3_k[12], Round_shifts[44])
    D = round3(D, A, B, C, X, R3_k[13], Round_shifts[45])
    C = round3(C, D, A, B, X, R3_k[14], Round_shifts[46])
    B = round3(B, C, D, A, X, R3_k[15], Round_shifts[47])
    
    # Update the MD buffer after processing each 512 bit block.
    A = (A + AA) & 0xFFFFFFFF
    B = (B + BB) & 0xFFFFFFFF
    C = (C + CC) & 0xFFFFFFFF
    D = (D + DD) & 0xFFFFFFFF
    
    # Display Updated MD Buffers
    print("\n*** MD Buffers after processing chunk ***\n[D C B A] = "+"[{:8x} {:8x} {:8x} {:8x}]".format(D,C,B,A)+"\n\n")

PROCESSING bytes 0...64

Message chunk being processed :
b'a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00' 

*** ROUND 1 ***
     R1 | K =  0 | s =  3 |  98badcfe
     R1 | K =  1 | s =  7 |  98bedffe
     R1 | K =  2 | s = 11 |  87458389
     R1 | K =  3 | s = 19 |      3a00
     R1 | K =  4 | s =  3 |  50943810
     R1 | K =  5 | s =  7 |   7457902
     R1 | K =  6 | s = 11 |  84c7dc42
     R1 | K =  7 | s = 19 |  e4c9c8be
     R1 | K =  8 | s =  3 |  e6b9233f
     R1 | K =  9 | s =  7 |  3968883e
     R1 | K = 10 | s = 11 |  19719e4a
     R1 | K = 11 | s = 19 |  83539e9b
     R1 | K = 12 | s =  3 |  99145bd4
     R1 | K = 13 | s =  7 |  7bb5c5e1
     R1 | K = 14 | s = 11 |  3f6dcfef
     R1 | K = 15 | s = 19 |  a7e3fe9f

*** ROUND 2 ***
     R2 | K =  0 | s =  3 |  2486388e
     

In [14]:
# Compute output hash from the MD buffers.
output_int = D<<96 | C <<64 | B << 32 | A

# The MD4 hash starts with the lowest order byte of A ... highest order byte of D
# format(..., '032x') keeps leading zero bytes that hex() would drop
print("OUTPUT      : 0x"+format(bytereverse(output_int), '032x'))
print("REF. Hash   : 0x"+test_case[use_test_case][1])

OUTPUT      : 0xbde52cb31de33e46245e05fbdbd6fb24
REF. Hash   : 0xbde52cb31de33e46245e05fbdbd6fb24
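The 48 explicit calls in the processing cell follow a regular pattern. The target register cycles A, D, C, B, and the other three arguments rotate with it.

The sketch below is a compact re-implementation that folds Steps 1 to 4 into loops. It uses the same shift amounts, word orders and round constants as the cells above. It is an illustration rather than the notebook's code, and `md4_compact` is a hypothetical name. Its test checks it against the seven test vectors defined at the top of the notebook.

```python
import struct

MASK = 0xFFFFFFFF

def rotl32(x, s):
    return ((x << s) | (x >> (32 - s))) & MASK

# (auxiliary function, additive constant, word order, shifts) for rounds 1-3
ROUNDS = [
    (lambda x, y, z: (x & y) | (~x & z), 0x00000000,
     list(range(16)), [3, 7, 11, 19]),
    (lambda x, y, z: (x & y) | (x & z) | (y & z), 0x5A827999,
     [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15], [3, 5, 9, 13]),
    (lambda x, y, z: x ^ y ^ z, 0x6ED9EBA1,
     [0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15], [3, 9, 11, 15]),
]

def md4_compact(message: bytes) -> str:
    # Steps 1 and 2: pad to 56 mod 64 bytes, then append the 64-bit bit length
    padded = message + b'\x80'
    padded += b'\x00' * ((56 - len(padded)) % 64)
    padded += ((8 * len(message)) % 2**64).to_bytes(8, byteorder='little')

    # Step 3: initial MD buffer [A, B, C, D]
    state = [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476]

    # Step 4: each 64-byte block is read as 16 little-endian 32-bit words
    for offset in range(0, len(padded), 64):
        X = struct.unpack('<16I', padded[offset:offset + 64])
        regs = list(state)
        for fn, const, order, shifts in ROUNDS:
            for i, k in enumerate(order):
                t = (-i) % 4  # target register index: A, D, C, B, A, ...
                a, b, c, d = (regs[(t + j) % 4] for j in range(4))
                regs[t] = rotl32((a + fn(b, c, d) + X[k] + const) & MASK, shifts[i % 4])
        # Add this block's result to the buffer (mod 2**32)
        state = [(s + r) & MASK for s, r in zip(state, regs)]

    # Output A, B, C, D, each as 4 little-endian bytes
    return struct.pack('<4I', *state).hex()

print(md4_compact(b"a"))
```

Using one loop per round trades the step-by-step printout of the notebook for brevity. The 62-byte test vector exercises the two-block path.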


## Compare with Python's `hashlib`

In [15]:
import hashlib

In [16]:
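# Use a name other than H so the auxiliary function H above is not overwritten.
# On OpenSSL 3.x builds MD4 may only be in the legacy provider, and this can raise ValueError.
md4_ref = hashlib.new('md4')
md4_ref.update(message_b)
md4hash = md4_ref.hexdigest()
print("Hashlib MD4 : 0x"+md4hash)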

Hashlib MD4 : 0xbde52cb31de33e46245e05fbdbd6fb24
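Depending on how Python was built, `hashlib.new('md4')` may not be available. With OpenSSL 3.x, MD4 is provided only by the legacy provider, and `hashlib` raises `ValueError` for an unsupported algorithm. Below is a hedged sketch of a guarded comparison; the helper name `hashlib_md4` is hypothetical.

```python
import hashlib

def hashlib_md4(data: bytes):
    """Return hashlib's MD4 hex digest, or None if MD4 is unavailable."""
    try:
        return hashlib.new('md4', data).hexdigest()
    except ValueError:
        return None

digest = hashlib_md4(b"a")
if digest is None:
    print("MD4 is not available in this Python/OpenSSL build")
else:
    print("Hashlib MD4 : 0x" + digest)
```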


### References

1. [RFC-1320](https://tools.ietf.org/html/rfc1320)
2. [Wikipedia](https://en.wikipedia.org/wiki/MD4)
3. [Rosetta Code](https://rosettacode.org/wiki/MD4#Python)