# Understanding Reduction Computation

## Placing the reduce axis at the end

In [489]:
import itertools
import operator

from shrimpgrad.util import prod

class Tensor:
    """Minimal strided view over a flat Python list, used to illustrate reductions.

    Fields:
      shape   -- tuple of dimension sizes.
      data    -- the flat backing sequence; this class stores it as given and
                 never copies or reorders it.
      strides -- tuple with one entry per dimension: how many elements of
                 ``data`` to skip when that dimension's index grows by one.
                 Row-major at construction time.
    """

    def __init__(self, shape, data):
        """Wrap ``data`` as a row-major tensor with the given ``shape``.

        Args:
            shape: sequence of dimension sizes; stored as a tuple.
            data: flat backing sequence, stored by reference.
        """
        self.shape = tuple(shape)
        self.data = data
        # A row-major stride is the product of all dimension sizes to its right,
        # i.e. a running product over the reversed trailing sizes. A 0-d shape
        # has no dimensions and therefore no strides.
        trailing = reversed(self.shape[1:])
        self.strides = tuple(itertools.accumulate(trailing, operator.mul, initial=1))[::-1] if self.shape else ()

    def permute(self, order):
        """Reorder the dimensions in place; ``data`` is left untouched.

        Args:
            order: iterable of dimension indices; entry ``k`` names the current
                dimension that becomes dimension ``k``. Negative indices count
                from the end, as in ordinary tuple indexing.

        Returns:
            None. ``shape`` and ``strides`` are rewritten on this instance.

        Raises:
            ValueError: if ``order`` is not a permutation of the dimensions.
        """
        # Materialise once so a one-shot iterable can drive both rebuilds below.
        order = tuple(order)
        ndim = len(self.shape)
        if sorted(i + ndim if i < 0 else i for i in order) != list(range(ndim)):
            raise ValueError(f"order {order} is not a permutation of the {ndim} dimensions")
        self.shape = tuple(self.shape[i] for i in order)
        self.strides = tuple(self.strides[i] for i in order)

# Example input: eight values viewed as a row-major (2, 2, 2) tensor.
t = Tensor((2,2,2), [1,2,3,4]*2)

# Reduce over axis 0. Reduced axes stay in the output with size 1, so the
# output shape follows directly from `axis` instead of being typed by hand.
axis = (0,)
shape = t.shape
out_shape = tuple(1 if i in axis else s for i, s in enumerate(shape))
out = Tensor(out_shape, [0]*prod(out_shape))

# Kept dimensions first (original relative order), reduced dimensions last.
# Classifying by membership in `axis` rather than by comparing sizes keeps a
# reduced dimension of size 1 from being mistaken for a kept one.
order = tuple([i for i in range(len(shape)) if i not in axis] + [i for i in range(len(shape)) if i in axis])
print(order)

# Apply the same permutation to input and output so their leading (kept)
# dimensions line up.
t.permute(order)
out.permute(order)


(1, 2, 0)


In [491]:
# Shape then strides of the permuted input, followed by the same pair for the output.
tuple(getattr(tensor, field) for tensor in (t, out) for field in ("shape", "strides"))

((2, 2, 2), (2, 1, 4), (2, 2, 1), (2, 1, 4))

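What does it mean for the computation when the reduce axis is always in the last dimension? Each output element becomes the sum of one innermost strided run of the input, so the loop nest below walks the kept dimensions on the outside and accumulates over the reduced dimension on the inside.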

In [482]:
# Zero the accumulator first so re-running this cell does not add onto the
# sums left by a previous run.
out.data[:] = [0] * len(out.data)

# `off` is the flat output position; it advances once per (m, r) pair, i.e.
# once per kept-dimension index.
off = 0
for m in range(t.shape[0]):
    moff = m * t.strides[0]
    for r in range(t.shape[1]):
        roff = r * t.strides[1]
        # Innermost loop runs over the reduced dimension, which the permute
        # moved to the last position.
        for c in range(t.shape[2]):
            coff = moff + roff + c*t.strides[2]
            out.data[off] += t.data[coff]
            print(off, t.data[coff])
        off += 1

print(out.data)


0 1
0 1
1 2
1 2
2 3
2 3
3 4
3 4
[2, 4, 6, 8]


In [400]:
import numpy as np
# NumPy reference: the values 1..4 repeated twice, viewed as a (2, 2, 2) array.
x = np.tile(np.array([1, 2, 3, 4]), 2).reshape(2, 2, 2)
print(x)

[[[1 2]
  [3 4]]

 [[1 2]
  [3 4]]]


In [428]:
# Shape of a keepdims reduction over the first two axes.
x.sum(axis=(0, 1), keepdims=True).shape

(1, 1, 2)

In [448]:
# Build the permutation for a reduction over axes (0, 1) right here instead of
# reusing whatever `order` an earlier cell happened to leave behind.
reduce_axes = (0, 1)
to_back = tuple(i for i in range(x.ndim) if i not in reduce_axes) + reduce_axes  # (2, 0, 1)
undo = tuple(int(i) for i in np.argsort(to_back))  # inverse permutation: (1, 2, 0)
# Sum over the two trailing (reduced) axes, then restore the original axis order.
x.transpose(to_back).sum((-1,-2), keepdims=True).transpose(undo)

array([[[ 8, 12]]])

## Merging dimensions

In [607]:
# Fresh input so this cell does not depend on permutes applied by earlier cells.
t = Tensor((2,2,2), [1,2,3,4]*2)

axis = (0,)
shape = t.shape
# Reduced axes stay in the output with size 1.
out_shape = tuple(1 if i in axis else s for i, s in enumerate(shape))
out = Tensor(out_shape, [0]*prod(out_shape))
print(f"{shape = } ")
print(f"{axis = }")
# Kept dimensions first, reduced dimensions last. Classifying by membership in
# `axis` keeps a reduced dimension of size 1 from being treated as kept.
order = tuple([i for i in range(len(shape)) if i not in axis] + [i for i in range(len(shape)) if i in axis])
print(f"{order = }")

t.permute(order)
out.permute(order)

print(f"{t.strides = } {t.shape = } {out_shape = }")

# Dims 0 and 1 can be walked as a single dimension of size shape[0]*shape[1]
# with step strides[1] only if stepping past the end of dim 1 lands exactly on
# the next index of dim 0.
assert t.shape[1] * t.strides[1] == t.strides[0], "dims 0 and 1 are not mergeable"

off = 0
for m in range(t.shape[0]*t.shape[1]):  # merged loop over dims 0 and 1
    moff = m * t.strides[1]
    for c in range(t.shape[2]):  # reduced dimension
        coff = moff + c*t.strides[2]
        print(f"{coff = }")
        out.data[off] += t.data[coff]
    off += 1
print(f"{out.data = }")

# Second variant: same merged outer loop, with the inner reduction loop
# unrolled by hand. Only `out.data` is used, so this fresh output is not permuted.
out = Tensor(out_shape, [0]*prod(out_shape))
# The unrolled body reads exactly two elements per output position.
assert t.shape[2] == 2, "unrolled body assumes a reduced dimension of size 2"
off = 0
for m in range(t.shape[0]*t.shape[1]):
    moff = m * t.strides[1]
    print(f"{moff = } {moff + t.strides[2] = }")
    out.data[off] += t.data[moff] + t.data[moff + t.strides[2]]
    off += 1

print("Merge and unroll")
print(f"{out.data = }")


shape = (2, 2, 2) 
axis = (0,)
order = (1, 2, 0)
t.strides = (2, 1, 4) t.shape = (2, 2, 2) out_shape = (1, 2, 2)
coff = 0
coff = 4
coff = 1
coff = 5
coff = 2
coff = 6
coff = 3
coff = 7
out.data = [2, 4, 6, 8]
moff = 0 moff + t.strides[2] = 4
moff = 1 moff + t.strides[2] = 5
moff = 2 moff + t.strides[2] = 6
moff = 3 moff + t.strides[2] = 7
Merge and unroll
out.data = [2, 4, 6, 8]


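Below, dims 0 and 1 cannot be merged, because the merge condition used above (`shape[1] * strides[1] == strides[0]`) does not hold: 2 * 4 != 1. Dims 1 and 2 can be merged, since `shape[2] * strides[2] == strides[1]` (2 * 2 == 4).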

In [609]:
# Fresh input so this cell does not depend on permutes applied by earlier cells.
t = Tensor((2,2,2), [1,2,3,4]*2)

axis = (0, 1)  # two-axis sum
shape = t.shape
# Reduced axes stay in the output with size 1, giving (1, 1, 2) here.
out_shape = tuple(1 if i in axis else s for i, s in enumerate(shape))
out = Tensor(out_shape, [0]*prod(out_shape))
print(f"{axis = }")
# Kept dimensions first, reduced dimensions last. Classifying by membership in
# `axis` keeps a reduced dimension of size 1 from being treated as kept.
order = tuple([i for i in range(len(shape)) if i not in axis] + [i for i in range(len(shape)) if i in axis])
print(f"{order = }")

t.permute(order)
out.permute(order)

print(f"{t.strides = } {t.shape = } {out_shape = }")

# Dims 0 and 1 cannot be merged here: shape[1]*strides[1] != strides[0].
# Dims 1 and 2 (both reduced) can, because walking dim 2 to its end lands on
# the next index of dim 1.
assert t.strides[1] == t.shape[2] * t.strides[2], "dims 1 and 2 are not mergeable"

off = 0
for m in range(t.shape[0]):  # kept dimension
    moff = m * t.strides[0]
    for c in range(t.shape[2]*t.shape[1]):  # merged loop over both reduced dims
        coff = moff + c*t.strides[2]
        print(f"{coff = }")
        out.data[off] += t.data[coff]
    off += 1
print(f"{out.data = }")





axis = (0, 1)
order = (2, 0, 1)
t.strides = (1, 4, 2) t.shape = (2, 2, 2) out_shape = (1, 1, 2)
coff = 0
coff = 2
coff = 4
coff = 6
coff = 1
coff = 3
coff = 5
coff = 7
out.data = [8, 12]
