In [1]:
import numpy as np

In [2]:
def ComputeLeftChild(tree, parent, computeCol=True):
    index = 2*parent+1
    if index >= len(tree):
        return
    
    delta = 2**int((int(np.log2(parent+1))/2)) # level = int(log2(index+1))     
    if computeCol:
        p = tree[parent].shape[0]
        tree[index] = np.zeros((p, p - delta), dtype=int)
        for i in xrange(0, p - delta):
            tree[index][:, i] = tree[parent][:, i] - tree[parent][:, i+delta]
            
    else:
        p = tree[parent].shape[1]
        tree[index] = np.zeros((p, p), dtype=int)
        for i in xrange(0, p):
            tree[index][i, :] = tree[parent][i, :] - tree[parent][i+delta, :]

    computeCol = not computeCol
    ComputeLeftChild(tree, index, computeCol)
    ComputeRightChild(tree, index, computeCol)
    return

In [3]:
def ComputeRightChild(tree, parent, computeCol=True):
    index = 2*parent+2
    if index >= len(tree):
        return
    
    delta = 2**int((int(np.log2(parent+1))/2)) # level = int(log2(index+1))
    if computeCol:
        p = tree[parent].shape[0]
        tree[index] = np.zeros((p, p - delta), dtype=int)
        for i in xrange(0, p - delta):
            tree[index][:, i] = tree[parent][:, i] + tree[parent][:, i+delta]
            
    else:
        p = tree[parent].shape[1]
        tree[index] = np.zeros((p, p), dtype=int)
        for i in xrange(0, p):            
            tree[index][i, :] = tree[parent][i, :] + tree[parent][i+delta, :] 
    
    computeCol = not computeCol
    ComputeLeftChild(tree, index, computeCol)
    ComputeRightChild(tree, index, computeCol)
    return

In [4]:
# Construct WHT compute tree
def WHTTree(input):
    # Check if input is valid
    m = np.log2(input.shape[0])
    if not m.is_integer:
        print("Error: WHTTree input shape is not power of 2!")
        return
    if not m == np.log2(input.shape[1]):
        print("Error: WHTTree input is not square matrix!")
        return
    
    h = 2*m # Levels of binary tree
    length = int(2**(h+1)-1) # Total length of full binary tree
    tree = np.empty(length, dtype=object)        
    tree[0] = input
        
    ComputeLeftChild(tree, 0, True)
    ComputeRightChild(tree, 0, True)                     
    return tree

In [21]:
m = 3 
p = 2**m # window(patch) size

In [22]:
#input = np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]])
input = np.random.randint(255, size=(p, p))
print(input)

[[189 249 215  36  80 106  98  38]
 [104  28 214   2 205  80 136 102]
 [175 157 174 120 174  34  54 142]
 [121 183 163 150  95  26  53 242]
 [ 37  70  56 137   7 107 253 188]
 [190  56  84  75  18 164  27   6]
 [186 241 204 103 234  40 252 236]
 [114  31 144 119 118 201  67 156]]


In [23]:
tree = WHTTree(input)

In [24]:
np.vstack(tree[-2**(2*m):]).flatten() # All the leaf nodes in the full binary tree to a flat array

array([ -302,   312,  -256,  -322,   608,  -322,  -154,  -388,   -90,
         104,  -100, -1074,   696,  -886,  -154,   152,  -570,  1004,
        -500,  -990, -1200,  -754,    78,   -68,    50,   196,   968,
         322,     0,   610,   494,   616,   652,   162,   618,  -536,
        -722,   980,  1228,  -214,  1072,   -30,   454,  -168,   902,
         224,  -644,   918,  -244,  -126,   394,   272,  -102,   496,
         360,  -226,  -328,  -134,   790, -1152,   842,   388,    24,  7866])

In [25]:
# Test in 800*600 image

In [26]:
IMAGE_WIDTH = 80
IAMGE_HEIGHT = 60

sourceImage = np.random.randint(255, size=(IAMGE_HEIGHT, IMAGE_WIDTH))

for i in xrange(0, IAMGE_HEIGHT - p):
    for j in xrange(0, IMAGE_WIDTH - p):
        input = sourceImage[i:i+p, j:j+p]
        # Can be simplified to better performance
        tree = WHTTree(input)
        np.vstack(tree[-2**(2*m):]).flatten()
        

In [None]:
# 3s perhaps
# Can be simplyfied 2**3 = 64, just compute the first(in a predefined order) 16 results 