In [1]:
import numpy as np

In [2]:
def ComputeLeftChild(tree, parent, computeCol=True):
    """Fill in the left child of ``tree[parent]`` and recurse into its subtree.

    ``tree`` is an implicit binary tree stored in a flat, indexable container:
    the children of node ``k`` live at ``2*k + 1`` (left) and ``2*k + 2``
    (right).  The left child holds the *difference* between entries of the
    parent that are ``delta`` apart, where ``delta = 2 ** (level // 2)`` and
    ``level = int(log2(parent + 1))`` is the parent's depth in the tree.

    Parameters
    ----------
    tree : indexable container of arrays
        Modified in place; ``tree[parent]`` must already hold a 2-D array.
    parent : int
        Index of the node whose left child is computed.
    computeCol : bool, optional
        If True, subtract columns that are ``delta`` apart (the child has
        ``delta`` fewer columns than the parent has rows).  If False, subtract
        rows that are ``delta`` apart (the child is square, sized by the
        parent's column count).  The flag is flipped for the next level.

    Returns
    -------
    None
        Results are written into ``tree``.  Nothing happens when the child
        index falls outside ``tree``.
    """
    index = 2*parent + 1
    if index >= len(tree):
        return

    level = int(np.log2(parent + 1))
    delta = 2 ** (level // 2)
    source = tree[parent]

    if computeCol:
        p = source.shape[0]
        child = np.zeros((p, p - delta), dtype=int)
        # Column i of the child is parent[:, i] - parent[:, i + delta].
        child[...] = source[:, :p - delta] - source[:, delta:p]
    else:
        p = source.shape[1]
        child = np.zeros((p, p), dtype=int)
        # Row i of the child is parent[i, :] - parent[i + delta, :].
        child[...] = source[:p, :] - source[delta:delta + p, :]
    # Writing into the preallocated buffer keeps the integer dtype and makes
    # a malformed parent shape fail loudly instead of producing a wrong size.
    tree[index] = child

    # A node whose own left-child slot is past the end of the tree is a leaf:
    # both recursive calls would return immediately, so skip them.
    if 2*index + 1 >= len(tree):
        return

    computeCol = not computeCol
    ComputeLeftChild(tree, index, computeCol)
    ComputeRightChild(tree, index, computeCol)
    return

In [3]:
def ComputeRightChild(tree, parent, computeCol=True):
    """Fill in the right child of ``tree[parent]`` and recurse into its subtree.

    ``tree`` is an implicit binary tree stored in a flat, indexable container:
    the children of node ``k`` live at ``2*k + 1`` (left) and ``2*k + 2``
    (right).  The right child holds the *sum* of entries of the parent that
    are ``delta`` apart, where ``delta = 2 ** (level // 2)`` and
    ``level = int(log2(parent + 1))`` is the parent's depth in the tree.

    Parameters
    ----------
    tree : indexable container of arrays
        Modified in place; ``tree[parent]`` must already hold a 2-D array.
    parent : int
        Index of the node whose right child is computed.
    computeCol : bool, optional
        If True, add columns that are ``delta`` apart (the child has ``delta``
        fewer columns than the parent has rows).  If False, add rows that are
        ``delta`` apart (the child is square, sized by the parent's column
        count).  The flag is flipped for the next level.

    Returns
    -------
    None
        Results are written into ``tree``.  Nothing happens when the child
        index falls outside ``tree``.
    """
    index = 2*parent + 2
    if index >= len(tree):
        return

    level = int(np.log2(parent + 1))
    delta = 2 ** (level // 2)
    source = tree[parent]

    if computeCol:
        p = source.shape[0]
        child = np.zeros((p, p - delta), dtype=int)
        # Column i of the child is parent[:, i] + parent[:, i + delta].
        child[...] = source[:, :p - delta] + source[:, delta:p]
    else:
        p = source.shape[1]
        child = np.zeros((p, p), dtype=int)
        # Row i of the child is parent[i, :] + parent[i + delta, :].
        child[...] = source[:p, :] + source[delta:delta + p, :]
    # Writing into the preallocated buffer keeps the integer dtype and makes
    # a malformed parent shape fail loudly instead of producing a wrong size.
    tree[index] = child

    # A node whose own left-child slot is past the end of the tree is a leaf:
    # both recursive calls would return immediately, so skip them.
    if 2*index + 1 >= len(tree):
        return

    computeCol = not computeCol
    ComputeLeftChild(tree, index, computeCol)
    ComputeRightChild(tree, index, computeCol)
    return

In [4]:
# Construct WHT compute tree
def WHTTree(input):
    """Build the full WHT compute tree for a square, power-of-two sized matrix.

    The tree is stored as a flat object array in implicit binary-tree layout
    (children of node ``k`` at ``2*k + 1`` and ``2*k + 2``).  For a
    ``2**m x 2**m`` input it has ``2*m`` levels below the root, i.e.
    ``2**(2*m + 1) - 1`` nodes; node 0 is ``input`` itself and the remaining
    nodes are filled by ``ComputeLeftChild`` / ``ComputeRightChild``.

    Parameters
    ----------
    input : numpy.ndarray
        2-D array whose two dimensions are equal and a power of two.

    Returns
    -------
    numpy.ndarray or None
        Object array holding every node of the tree, or ``None`` (after
        printing a message) when ``input`` has an unsupported shape.
    """
    # Check if input is valid
    rows = input.shape[0]
    # A positive integer is a power of two exactly when it has a single set
    # bit.  (The previous ``m.is_integer`` test never called the method, so
    # it could not reject anything.)
    if rows < 1 or rows & (rows - 1):
        print("input shape is not power of 2!")
        return
    if rows != input.shape[1]:
        print("input is not square matrix!")
        return

    m = rows.bit_length() - 1  # exact log2 of the side length
    h = 2*m  # Levels of binary tree
    length = 2**(h + 1) - 1  # Total length of full binary tree
    tree = np.empty(length, dtype=object)
    tree[0] = input

    ComputeLeftChild(tree, 0, True)
    ComputeRightChild(tree, 0, True)
    return tree

In [64]:
m = 2  # exponent: the window side length is 2**m
p = 1 << m  # window (patch) size

In [65]:
# Deterministic 4x4 sample holding 1..16 in row-major order.
input = np.arange(1, 17).reshape(4, 4)
#input = np.random.randint(255, size=(p, p))
print(input)

[[ 1  2  3  4]
 [ 5  6  7  8]
 [ 9 10 11 12]
 [13 14 15 16]]


In [66]:
# Build the complete compute tree for the sample matrix; node 0 is the input itself.
tree = WHTTree(input)

In [67]:
# The last 4**m (= 2**(2*m)) entries are the leaves of the full binary tree;
# stack them and flatten the result into a single 1-D array.
np.vstack(tree[-4**m:]).ravel()

array([  0,   0,   0,   0,   0,   0,   0,  -8,   0,   0,   0, -32,   0,
       -16, -64, 136])

In [None]:
# Test on an 800x600 image

In [69]:
IMAGE_WIDTH = 800
IMAGE_HEIGHT = 600
IAMGE_HEIGHT = IMAGE_HEIGHT  # original (misspelled) name, kept so existing references still work

# Arrays are indexed [row, column], so the image must be (height, width):
# the loops below use i (bounded by the height) as the row index and
# j (bounded by the width) as the column index.
sourceImage = np.random.randint(255, size=(IMAGE_HEIGHT, IMAGE_WIDTH))

# Slide a 2p x 2p window across the interior of the image.  The window is
# only extracted here; nothing is computed on it yet.
for i in range(p, IMAGE_HEIGHT - p):
    for j in range(p, IMAGE_WIDTH - p):
        input = sourceImage[i-p:i+p, j-p:j+p]

In [72]:
# Take the 6x6 window covering rows 17..22 and columns 12..17 of the image
# (3 pixels to either side of row 20, column 15).
input = sourceImage[20-3:20+3, 15-3:15+3]

In [73]:
# Display the extracted window.
input

array([[221, 232, 190, 127,  87, 220],
       [ 97,  34, 200, 200,  99,  63],
       [ 57,  54,  13, 233,  91, 252],
       [199,  15, 230, 117, 209,  89],
       [ 83, 125,  18, 162,  96,  98],
       [233,   6, 125, 141, 125, 176]])