In [1]:
from gradflow import Tensor
import gradflow.functions as F

import torch
import torch.nn as nn

import numpy as np
from numpy.lib.stride_tricks import as_strided

In [2]:
def max_pool(image: np.ndarray, kernel_size: tuple[int, int], stride: int) -> tuple[np.ndarray, np.ndarray]:
  """Max-pool a batch of images using a strided window view.

  Args:
    image: array of shape (batch, channels, w, h). Any memory layout is
      accepted because the window strides are derived from ``image.strides``.
    kernel_size: (kw, kh) window extent along the last two axes.
    stride: step between consecutive windows, applied to both axes.

  Returns:
    A ``(patches, pooled)`` tuple. ``patches`` is a view of ``image`` with
    shape (batch, channels, new_w, new_h, kw, kh); ``pooled`` holds the
    maximum of every window and has shape (batch, channels, new_w, new_h).

  Raises:
    ValueError: if the kernel is larger than the image along either axis.
  """
  bs, c, w, h = image.shape
  kw, kh = kernel_size

  if kw > w or kh > h:
    raise ValueError(
      f"kernel_size {kernel_size} does not fit into an image of size {(w, h)}")

  new_w = (w - kw) // stride + 1
  new_h = (h - kh) // stride + 1

  # Reuse the array's own byte strides instead of recomputing them from the
  # shape: this stays correct for non-square images (the row pitch is h, not w)
  # and for inputs that are not C-contiguous.
  s_batch, s_chan, s_row, s_col = image.strides

  # NOTE: the windows alias the memory of `image` (and each other when
  # stride < kernel), so the view should be treated as read-only.
  patches = as_strided(image, shape=(bs, c, new_w, new_h, kw, kh),
                       strides=(
                         s_batch,
                         s_chan,
                         stride*s_row, stride*s_col, s_row, s_col))

  return patches, patches.max((-2, -1))

In [3]:
def max_pool_backward(im, patches, max_patches):
  """Debug helper that traces shapes and strides for the max-pool backward pass.

  Builds a 0/1 mask with the shape of ``patches`` marking every window entry
  equal to its window maximum, re-views that mask with the shape and strides
  of ``im``, and prints the shape/strides of each stage. Nothing is returned.

  NOTE(review): the mask is allocated in per-window layout, so viewing it with
  the image's strides does not move each entry back to its source pixel. If
  ``patches`` holds fewer elements than ``im``, the view would also extend past
  the mask's buffer. Only the printed metadata is meaningful here.
  """
  # Backward of max: 1 where a window entry equals the window maximum, else 0.
  # The upstream gradient is not applied yet; a constant 1 stands in for it.
  is_max = patches == max_patches[:, :, :, :, None, None]
  window_grad = np.zeros_like(patches)
  window_grad[is_max] = 1

  for label, value in (
      ("mn fw max old shape", patches.shape),
      ("mn fw max old strides", patches.strides),
      ("mn bk max new_grad shape", window_grad.shape),
      ("mn bk max new_grad strides", window_grad.strides),
  ):
    print(label, value)
  print()

  # Backward of as_strided: reinterpret the mask with the image's geometry.
  print("im shape", im.shape)
  print("im strides", im.strides)
  image_view = as_strided(window_grad, im.shape, im.strides)
  print("reim shape", image_view.shape)
  print("reim strides", image_view.strides)

In [4]:
# Toy batch: two identical copies of a 3-channel 4x4 image holding 1..48.
im = np.repeat(np.arange(1, 49).reshape(1, 3, 4, 4), 2, axis=0)
# Four identical 3x2x2 filters whose entries alternate 1, 2.
kernel = np.repeat(np.tile([1, 2], 6).reshape(1, 3, 2, 2), 4, axis=0)
# 2x2 windows taken with stride 2, plus the maximum of every window.
patches, maxpool = max_pool(im, (2, 2), 2)

In [5]:
# Bug
# patches.max(), np.expand_dims(maxpool, None)

In [6]:
# max_pool_backward(im, patches, maxpool)

In [7]:
# as_strided(patches, im.shape, im.strides)

In [9]:
# gradflow max-pool layer; presumably (kernel_size=2, stride=2), matching the
# NumPy max_pool(im, (2, 2), 2) call above — TODO confirm against gradflow.
mp = F.MaxPool2D(2, 2)

# Wrap the toy batch as a float64 tensor that tracks gradients.
gim = Tensor(im, requires_grad=True, dtype=np.float64)
mpool = mp(gim)
# Reduce to a scalar so backward() can be started from it.
f = mpool.mean()
f.backward()

# gim.grad.astype(np.float16)

# mn fw max old shape (2, 3, 2, 2, 2, 2)
# mn fw max old strides (384, 128, 64, 16, 32, 8)

fw strides Old shape (2, 3, 4, 4)
fw strides Old strides (384, 128, 32, 8)
strides Old dtype float64
fw strides New shape (2, 3, 2, 2, 2, 2)
fw strides New strides (384, 128, 64, 16, 32, 8)

max old shape (2, 3, 2, 2, 2, 2)
max old strides (192, 64, 32, 8, 16, 4)
fw max new shape (2, 3, 2, 2)
fw max new strides (48, 16, 8, 4)

fw sum shape (2, 3, 2, 2)
fw sum strides (48, 16, 8, 4)
bk grad shape (2, 3, 2, 2)
bk grad strides (48, 16, 8, 4)

bk max grad shape (2, 3, 2, 2)
bk max grad strides (48, 16, 8, 4)
bk max new_grad shape (2, 3, 2, 2, 2, 2)
bk max new_grad strides (192, 64, 32, 8, 16, 4)

bk strides Old shape (2, 3, 4, 4)
bk strides Old strides (384, 128, 32, 8)
bk strides grad shape (2, 3, 2, 2, 2, 2)
bk strides grad strides (192, 64, 32, 8, 16, 4)
strides grad dtype float32



In [20]:
# gim.grad

In [11]:
# mpool

In [38]:


# gradflow convolution without bias; arguments presumably mirror
# nn.Conv2d(3, 4, 2, stride=1) used in the PyTorch cell — TODO confirm.
o = F.Conv2d(3, 4, 2, 1, bias=False)
# Replace the layer weights with the fixed `kernel` so results are comparable.
o.weight = Tensor(kernel, requires_grad=True)
oz = o(Tensor(im))
# Reduce to a scalar and backpropagate to populate o.weight.grad.
oz = oz.mean()
oz.backward()
# Display the shape and byte strides of the weight gradient for inspection.
o.weight.grad.shape, o.weight.grad.strides

fw strides Old shape (2, 3, 4, 4)
fw strides Old strides (192, 64, 16, 4)
strides Old dtype float32
fw strides New shape (2, 3, 3, 3, 2, 2)
fw strides New strides (64, 16, 4, 64, 16, 4)

fw sum shape (2, 4, 3, 3)
fw sum strides (144, 36, 12, 4)
bk grad shape (2, 4, 3, 3)
bk grad strides (144, 36, 12, 4)



((4, 3, 2, 2), (4, 64, 32, 16))

In [32]:
# PyTorch reference: same input batch and same fixed weights as the gradflow
# convolution, so forward values and weight gradients can be compared.
tim = torch.tensor(im, dtype=torch.float32)
conv = nn.Conv2d(3, 4, 2, stride=1, bias=False)
# Named `torch_maxpool` so it does not overwrite the NumPy `maxpool` array
# returned by max_pool(); reusing the name made results depend on cell order.
torch_maxpool = nn.MaxPool2d((2, 2), 2)
conv.weight = nn.Parameter(torch.tensor(kernel, dtype=torch.float32))
# conv.bias = nn.Parameter(torch.tensor(o.bias.data.reshape(-1), dtype=torch.float32))

# Reduce to a scalar and backpropagate to populate conv.weight.grad.
z = conv(tim)
z = z.mean()
z.backward()

# gtim = torch.tensor(im, dtype=torch.float32, requires_grad=True)
# tpool = torch_maxpool(gtim)
# tf = tpool.mean()
# tf.backward()

In [46]:
# Flattened PyTorch weight gradient: the reference values.
conv.weight.grad.numpy().ravel()

array([ 1.5      ,  1.75     ,  2.5000002,  2.75     ,  5.5      ,
        5.75     ,  6.5      ,  6.75     ,  9.5      ,  9.75     ,
       10.5      , 10.75     ,  1.5      ,  1.75     ,  2.5000002,
        2.75     ,  5.5      ,  5.75     ,  6.5      ,  6.75     ,
        9.5      ,  9.75     , 10.5      , 10.75     ,  1.5      ,
        1.75     ,  2.5000002,  2.75     ,  5.5      ,  5.75     ,
        6.5      ,  6.75     ,  9.5      ,  9.75     , 10.5      ,
       10.75     ,  1.5      ,  1.75     ,  2.5000002,  2.75     ,
        5.5      ,  5.75     ,  6.5      ,  6.75     ,  9.5      ,
        9.75     , 10.5      , 10.75     ], dtype=float32)

In [45]:
# Flattened gradflow weight gradient, for side-by-side comparison with PyTorch.
o.weight.grad.ravel()

array([3.5 , 3.75, 4.5 , 4.75, 7.5 , 7.75, 8.5 , 8.75, 5.5 , 5.75, 6.5 ,
       6.75, 3.5 , 3.75, 4.5 , 4.75, 7.5 , 7.75, 8.5 , 8.75, 5.5 , 5.75,
       6.5 , 6.75, 3.5 , 3.75, 4.5 , 4.75, 7.5 , 7.75, 8.5 , 8.75, 5.5 ,
       5.75, 6.5 , 6.75, 3.5 , 3.75, 4.5 , 4.75, 7.5 , 7.75, 8.5 , 8.75,
       5.5 , 5.75, 6.5 , 6.75], dtype=float32)

In [34]:
# gtim.grad

In [15]:
# np.allclose(mpool, tpool.detach().numpy())

In [16]:
# Forward check: gradflow mean output vs. PyTorch mean output.
# Needs the cells that define `oz` and `z` to have run first.
np.allclose(oz.data, z.detach().numpy())

NameError: name 'oz' is not defined

In [None]:
# Sanity check: both layers should hold the same weights before comparing gradients.
np.allclose(o.weight.data, conv.weight.detach().numpy())

In [None]:
# Backward check: gradflow weight gradient vs. PyTorch weight gradient.
np.allclose(o.weight.grad, conv.weight.grad.numpy())

In [None]:
# Inspect the gradflow layer's bias (the layer was created with bias=False).
o.bias

In [None]:
# Inspect the PyTorch layer's bias (the layer was created with bias=False).
conv.bias

In [None]:
# nn.Conv2d(..., bias=False) leaves conv.bias as None, so calling .detach() on
# it would raise. With no PyTorch bias, the layers agree only if gradflow has
# none either; otherwise compare the values.
(o.bias is None) if conv.bias is None else np.allclose(o.bias.data.reshape(-1), conv.bias.detach().numpy())