In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn

In [3]:
class Python_MaxPool(object):

    @staticmethod
    def forward(x, pool_param, layer_no=[], save_txt=False, save_hex=False, phase=[]):

        out = None
        all_positions = []
        positions = []
        stride = pool_param['stride']
        pool_width = pool_param['pool_width']
        pool_height = pool_param['pool_height']
        N, C, H, W = x.shape
        H_out = int(1 + (H - pool_height) / stride)
        W_out = int(1 + (W - pool_width) / stride)
        out = torch.zeros((N, C, H_out, W_out), dtype=x.dtype, device=x.device)
        for n in range(N):
            for height in range(H_out):
                for width in range(W_out):
                    val, index = x[n, :, height * stride:height * stride + pool_height,
                             width * stride:width * stride + pool_width].reshape(C, -1).max(dim=1)
                    positions.append(index)
                    out[n, :, height, width] = val
            

        cache = (x, pool_param)
        positions = torch.cat(positions).squeeze().to(torch.int) # List into a Single Tensor
        new_positions = positions.reshape(N, C, H_out, W_out)

                        
        # Sign of Max Value in MaxPooling: 
        sign = torch.zeros_like(out)
        sign[out < 0] = 1


        return out, cache, positions, new_positions

{'pool_width': 2, 'pool_height': 2, 'stride': 2}

In [5]:
torch.manual_seed(42)
x = torch.randn(2, 2, 4, 4)
pool_param = {'pool_width': 2, 'pool_height': 2, 'stride': 2}
print(x)
y1, _cache, positions, new_positions = Python_MaxPool.forward(x, pool_param)
print(y1, '\n', positions, '\n', new_positions)

tensor([[[[ 1.9269,  1.4873,  0.9007, -2.1055],
          [ 0.6784, -1.2345, -0.0431, -1.6047],
          [-0.7521,  1.6487, -0.3925, -1.4036],
          [-0.7279, -0.5594, -0.7688,  0.7624]],

         [[ 1.6423, -0.1596, -0.4974,  0.4396],
          [-0.7581,  1.0783,  0.8008,  1.6806],
          [ 1.2791,  1.2964,  0.6105,  1.3347],
          [-0.2316,  0.0418, -0.2516,  0.8599]]],


        [[[-1.3847, -0.8712, -0.2234,  1.7174],
          [ 0.3189, -0.4245,  0.3057, -0.7746],
          [-1.5576,  0.9956, -0.8798, -0.6011],
          [-1.2742,  2.1228, -1.2347, -0.4879]],

         [[-0.9138, -0.6581,  0.0780,  0.5258],
          [-0.4880,  1.1914, -0.8140, -0.7360],
          [-1.4032,  0.0360, -0.0635,  0.6756],
          [-0.0978,  1.8446, -1.1845,  1.3835]]]])
tensor([[[[ 1.9269,  0.9007],
          [ 1.6487,  0.7624]],

         [[ 1.6423,  1.6806],
          [ 1.2964,  1.3347]]],


        [[[ 0.3189,  1.7174],
          [ 2.1228, -0.4879]],

         [[ 1.1914,  0.5258],
   

In [42]:
h_out, w_out = 3, 3

In [43]:
test2 = new_positions[:, :, :, 1].view(-1).tensor_split(3)
print(test2)

(tensor([0, 3, 1], dtype=torch.int32), tensor([1, 3, 1], dtype=torch.int32), tensor([3, 3], dtype=torch.int32))


In [44]:
test1 = new_positions[:, :, :, 0].view(-1).tensor_split(3)
print(test1)

(tensor([0, 0, 1], dtype=torch.int32), tensor([3, 2, 1], dtype=torch.int32), tensor([3, 3], dtype=torch.int32))


In [45]:
final = torch.cat([test1[0], test2[0], test1[1], test2[1]])

In [46]:
print(final)

tensor([0, 0, 1, 0, 3, 1, 3, 2, 1, 1, 3, 1], dtype=torch.int32)


In [47]:
torch.manual_seed(42)
x = torch.randn(2, 2, 4, 4)
pool_param = {'pool_width': 2, 'pool_height': 2, 'stride': 2}
print(x)
y1, _, positions, new_positions = Python_MaxPool.forward(x, pool_param)
print(y1, '\n', positions, '\n', new_positions)

tensor([[[[ 1.9269,  1.4873,  0.9007, -2.1055],
          [ 0.6784, -1.2345, -0.0431, -1.6047],
          [-0.7521,  1.6487, -0.3925, -1.4036],
          [-0.7279, -0.5594, -0.7688,  0.7624]],

         [[ 1.6423, -0.1596, -0.4974,  0.4396],
          [-0.7581,  1.0783,  0.8008,  1.6806],
          [ 1.2791,  1.2964,  0.6105,  1.3347],
          [-0.2316,  0.0418, -0.2516,  0.8599]]],


        [[[-1.3847, -0.8712, -0.2234,  1.7174],
          [ 0.3189, -0.4245,  0.3057, -0.7746],
          [-1.5576,  0.9956, -0.8798, -0.6011],
          [-1.2742,  2.1228, -1.2347, -0.4879]],

         [[-0.9138, -0.6581,  0.0780,  0.5258],
          [-0.4880,  1.1914, -0.8140, -0.7360],
          [-1.4032,  0.0360, -0.0635,  0.6756],
          [-0.0978,  1.8446, -1.1845,  1.3835]]]])
tensor([[[[ 1.9269,  0.9007],
          [ 1.6487,  0.7624]],

         [[ 1.6423,  1.6806],
          [ 1.2964,  1.3347]]],


        [[[ 0.3189,  1.7174],
          [ 2.1228, -0.4879]],

         [[ 1.1914,  0.5258],
   

In [63]:
torch.manual_seed(33)
x = torch.randn(3, 2, 6, 6)
pool_param = {'pool_width': 2, 'pool_height': 2, 'stride': 2}
print(x)
y1, _, positions, new_positions = Python_MaxPool.forward(x, pool_param)
print(y1, '\n', positions, '\n', new_positions)

tensor([[[[ 0.6014,  0.1087, -0.4499,  0.7841,  0.6550, -0.3062],
          [ 1.3935,  0.0631, -1.2514,  1.2745,  0.4777, -0.4516],
          [ 0.1392, -1.5146,  0.2888,  0.5565, -0.4415, -1.5009],
          [ 1.2530, -0.6902,  0.2395,  1.3365, -0.5728,  0.5368],
          [ 0.4898, -1.2608, -0.3192,  0.3207, -0.1966, -0.7767],
          [-1.4657, -1.0870, -1.0364,  0.3204, -1.0311,  0.7873]],

         [[-1.3096, -0.3365, -1.2034,  1.5051, -0.2974,  0.3986],
          [ 0.4957,  0.8755,  0.5996,  0.1281,  0.2540, -0.3113],
          [-1.1712, -1.0707,  1.5110, -1.4429, -0.3396, -1.6925],
          [-0.3451,  1.3243,  0.1380, -0.7729,  0.9246,  0.5713],
          [ 0.2663, -0.3831,  0.7125, -2.7313, -0.1294, -2.4284],
          [-1.8523,  0.0190, -1.2144,  0.8610,  0.2427,  0.9375]]],


        [[[ 0.3038,  1.2356,  0.6882, -0.6691, -0.6764, -0.6189],
          [ 0.6844, -0.1688,  1.6709, -1.4104, -0.1584,  1.3174],
          [ 1.8679, -0.2671,  0.0272, -1.2297, -0.7697, -0.0064],
    

In [62]:
new_positions = new_positions[:, :, :, 0].view(-1)
print(new_positions, new_positions.shape)

IndexError: too many indices for tensor of dimension 3

In [64]:
test1 = new_positions[:, :, :, 1].view(-1).tensor_split(3)
print(test1)

(tensor([3, 0, 3, 3, 0, 3], dtype=torch.int32), tensor([0, 3, 1, 1, 2, 2], dtype=torch.int32), tensor([0, 2, 3, 1, 0, 2], dtype=torch.int32))


In [65]:
test2 = new_positions[:, :, :, 1].view(-1).tensor_split(3)
print(test2)

(tensor([3, 0, 3, 3, 0, 3], dtype=torch.int32), tensor([0, 3, 1, 1, 2, 2], dtype=torch.int32), tensor([0, 2, 3, 1, 0, 2], dtype=torch.int32))


In [66]:
test3 = new_positions[:, :, :, 2].view(-1).tensor_split(3)
print(test3)

(tensor([3, 1, 3, 2, 1, 3], dtype=torch.int32), tensor([2, 1, 0, 0, 0, 1], dtype=torch.int32), tensor([0, 1, 3, 2, 0, 1], dtype=torch.int32))


In [None]:
final_pos = torch.cat([[k[i] for i in ]])

In [None]:
final = torch.cat([test1[0], test2[0], test3[0], test1[1], test2[1], test2[1], test3[0], test3[1], test3[2]])