In [None]:
def conv2d(x, w, b=None, padding=0, stride=1, dilation=1, groups=1, ctx=None):
  """A differentiable convolution of 2d tensors.

  Note: Read this following documentation regarding the output's shape.
  https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d

  Backward call:
    backward_fn: conv2d_backward
    args: y, x, w, b, padding, stride, dilation

  Args:
    x (torch.Tensor): The input tensor.
      Has shape `(batch_size, in_channels, height, width)`.
    w (torch.Tensor): The weight tensor.
      Has shape `(out_channels, in_channels, kernel_height, kernel_width)`.
    b (torch.Tensor): The bias tensor. Has shape `(out_channels,)`.
    padding (Tuple[int, int] or int, Optional): The padding in each dimension (height, width).
      Defaults to 0.
    stride (Tuple[int, int] or int, Optional): The stride in each dimension (height, width).
      Defaults to 1.
    dilation (Tuple[int, int] or int, Optional): The dilation in each dimension (height, width).
      Defaults to 1.
    groups (int, Optional): Number of groups. Defaults to 1.
    ctx (List, optional): The autograd context. Defaults to None.

  Returns:
    y (torch.Tensor): The output tensor.
      Has shape `(batch_size, out_channels, out_height, out_width)`.
  """
  assert w.size(0) % groups == 0, \
    f'expected w.size(0)={w.size(0)} to be divisible by groups={groups}'
  assert x.size(1) % groups == 0, \
    f'expected x.size(1)={x.size(1)} to be divisible by groups={groups}'
  assert x.size(1) // groups == w.size(1), \
    f'expected w.size(1)={w.size(1)} to be x.size(1)//groups={x.size(1)}//{groups}'

  # BEGIN SOLUTION
  # extract and parse input
  if b is None:
    b = torch.zeros(out_channels, dtype=x.dtype, device=x.device)
  #Extract input dimensions :
  batch_size, _, height, width = x.shape
  #Extract weight dimensions:
  out_channels, _, kernel_height, kernel_width = w.shape
  in_channels = w.shape[1] * groups
  padding_h, padding_w = _pair(padding) if type(padding) is int else padding
  stride_h, stride_w = _pair(stride) if type(stride) is int else stride
  dilation_h, dilation_w = _pair(dilation) if type(dilation) is int else dilation
  out_height = int(1 + ((height + 2*padding_h - dilation_h * (kernel_height - 1) - 1) / (stride_h)))
  out_width = int(1 + ((width + 2*padding_w - dilation_w * (kernel_width - 1) - 1) / (stride_w)))
  
  # unfold x and split to groups
  patches = unfold(x, (kernel_height, kernel_width),
                    dilation=dilation,
                    padding=padding, 
                    stride=stride)
  patches = patches.reshape(batch_size, in_channels, kernel_height*kernel_width, out_height*out_width)
  patches_group_split = torch.stack(patches.chunk(chunks=groups, dim=1), dim=1)
  patches_group_split = patches_group_split.reshape(batch_size, groups, int(in_channels/groups)*kernel_height*kernel_width, out_height*out_width)
  
  # split w to groups
  w_group_split = w.reshape(out_channels, int(in_channels/groups)*kernel_height*kernel_width)
  w_group_split = torch.stack(w_group_split.chunk(chunks=groups, dim=0))
  w_group_split = w_group_split.reshape(1, groups, int(out_channels/groups), int(in_channels/groups)*kernel_height*kernel_width)
  
  # compute convolution as matrix multiplication
  convolved_patches = w_group_split @ patches_group_split
  convolved_patches = convolved_patches.reshape(batch_size, out_channels, out_height, out_width)
  convolved_patches += b.reshape(1, -1, 1, 1)

  y = convolved_patches

  if ctx is not None:
    ctx += [(conv2d_backward, [y, x, w, b, padding, stride, dilation, groups])]
  
  return y
  # END SOLUTION


def conv2d_backward(y, x, w, b, padding, stride, dilation, groups):
  """Backward computation of `conv2d`.

  Propagates the gradients of `y` (in `y.grad`) to `x`, `w` and `b` (if `b` is not None),
  and accumulates them in `x.grad`, `w.grad` and `b.grad`, respectively.

  Args:
    y (torch.Tensor): The output tensor.
      Has shape `(batch_size, out_channels, out_height, out_width)`.
    x (torch.Tensor): The input tensor.
      Has shape `(batch_size, in_channels, height, width)`.
    w (torch.Tensor): The weight tensor.
      Has shape `(out_channels, in_channels, kernel_height, kernel_width)`.
    b (torch.Tensor): The bias tensor. Has shape `(out_channels,)`.
    padding (Tuple[int, int] or int, Optional): The padding in each dimension (height, width).
      Defaults to 0.
    stride (Tuple[int, int] or int, Optional): The stride in each dimension (height, width).
      Defaults to 1.
    dilation (Tuple[int, int] or int, Optional): The dilation in each dimension (height, width).
      Defaults to 1.
    groups (int, Optional): Number of groups. Defaults to 1.
  """
  # BEGIN SOLUTION
  # extract and parse input
  batch_size, _, height, width = x.shape
  out_channels, _, kernel_height, kernel_width = w.shape
  in_channels = w.shape[1] * groups
  padding_h, padding_w = _pair(padding) if type(padding) is int else padding
  stride_h, stride_w = _pair(stride) if type(stride) is int else stride
  dilation_h, dilation_w = _pair(dilation) if type(dilation) is int else dilation
  out_height = int(1 + ((height + 2*padding_h - dilation_h * (kernel_height - 1) - 1) / (stride_h)))
  out_width = int(1 + ((width + 2*padding_w - dilation_w * (kernel_width - 1) - 1) / (stride_w)))
  
  # unfold x and split to groups
  patches = unfold(x, (kernel_height, kernel_width),
                    dilation=dilation,
                    padding=padding, 
                    stride=stride)
  patches = patches.reshape(batch_size, in_channels, kernel_height*kernel_width, out_height*out_width)
  patches = patches.permute([0, 3, 1, 2])
  patches_group_split = torch.stack(patches.chunk(chunks=groups, dim=2), dim=0)
  patches_group_split = patches_group_split.reshape(groups, batch_size*out_height*out_width, int(in_channels/groups)*kernel_height*kernel_width)

  # b.grad
  b.grad += y.grad.sum([0, 2, 3])

  # w.grad
  y_grad_permuted = y.grad.permute([1, 0, 2, 3])
  y_grad_permuted = y_grad_permuted.reshape(out_channels, batch_size*out_height*out_width)
  y_grad_permuted_group_split = torch.stack(y_grad_permuted.chunk(chunks=groups, dim=0), dim=0)
  w_grad_permuted = y_grad_permuted_group_split.matmul(patches_group_split)
  w.grad += w_grad_permuted.reshape(out_channels, int(in_channels/groups), kernel_height, kernel_width)

  # patches.grad (unfolded x)
  y_grad_permuted_group_split = y_grad_permuted_group_split.permute([0, 2, 1])
  w_permuted = w.reshape(out_channels, -1)
  w_permuted_group_split = torch.stack(w_permuted.chunk(chunks=groups, dim=0), dim=0)
  patches_grad_group_split = y_grad_permuted_group_split.matmul(w_permuted_group_split)

  # x.grad (fold the reshaped patches.grad)
  patches_grad_group_split = patches_grad_group_split.reshape(groups, batch_size*out_height*out_width, int(in_channels/groups), kernel_height*kernel_width)
  patches_grad_group_split = patches_grad_group_split.permute([1, 0, 2, 3])
  patches_grad_group_split = patches_grad_group_split.reshape(batch_size, out_height*out_width, in_channels*kernel_height*kernel_width)
  patches_grad_group_split = patches_grad_group_split.permute([0, 2, 1])
  x.grad += fold(patches_grad_group_split, 
                output_size=(height, width), 
                kernel_size=(kernel_height, kernel_width),
                dilation=dilation,
                padding=padding, 
                stride=stride)
  # END SOLUTION

In [2]:
from torch.nn.modules.utils import _pair, _ntuple

_pair(1)

(1, 1)

In [28]:
from torch.nn.functional import unfold, fold
import torch

#make a 64*64 image with 3 channels and 1 batch:
x = torch.tensor([[[[1., 2., 3., 4.],
                    [5., 6., 7., 8.],
                    [9., 10., 11., 12.],
                    [13., 14., 15., 16.]],
                   [[17., 18., 19., 20.],
                    [21., 22., 23., 24.],
                    [25., 26., 27., 28.],
                    [29., 30., 31., 32.]],
                   [[33., 34., 35., 36.],
                    [37., 38., 39., 40.],
                    [41., 42., 43., 44.],
                    [45., 46., 47., 48.]]]])

w = torch.tensor([[[[1., 2.],
                    [3., 4.],
                    [5., 6.]],
                   [[7., 8.],
                    [9., 10.],
                    [11., 12.]],
                   [[13., 14.],
                    [15., 16.],
                    [17., 18.]]],
                  [[[19., 20.],
                    [21., 22.],
                    [23., 24.]],
                   [[25., 26.],
                    [27., 28.],
                    [29., 30.]],
                   [[31., 32.],
                    [33., 34.],
                    [35., 36.]]]])

out_channels, _, kernel_height, kernel_width = w.shape

padding=0
stride=1
dilation=1
groups = 1

#Extract input dimensions :
batch_size, _, height, width = x.shape
#Extract weight dimensions:
out_channels, _, kernel_height, kernel_width = w.shape
in_channels = w.shape[1] * groups

pad_h, pad_w = _pair(padding) if type(padding) is int else padding
stride_h, stride_w = _pair(stride) if type(stride) is int else stride
dil_h, dil_w = _pair(dilation) if type(dilation) is int else dilation


In [30]:
out_height = int(1 + ((height + 2 * pad_h - pad_w * (kernel_height - 1) - 1) / (stride_h)))
out_width = int(1 + ((width + 2 * pad_w - dil_w * (kernel_width - 1) - 1) / (stride_w)))

patches = unfold(x, (kernel_height, kernel_width),
                dilation=dilation,
                padding=padding, 
                stride=stride)

patches = patches.reshape(batch_size, in_channels, kernel_height*kernel_width, out_height*out_width)



RuntimeError: shape '[1, 3, 6, 12]' is invalid for input of size 108

In [43]:
import torch
from torch.nn.functional import unfold, fold


# Example input and parameters
batch_size, in_channels, height, width = 1, 3, 5, 5
out_channels, kernel_height, kernel_width = 2, 3, 3
stride, padding, dilation, groups = 1, 1, 1, 1

x = torch.arange(batch_size * in_channels * height * width, dtype=torch.float32).reshape(batch_size, in_channels, height, width)
w = torch.ones(out_channels, in_channels // groups, kernel_height, kernel_width)
b = torch.zeros(out_channels)

print("Input (x):", x.shape)
print(x)

# Padding, stride, and dilation as tuples
padding_h, padding_w = padding, padding
stride_h, stride_w = stride, stride
dilation_h, dilation_w = dilation, dilation

out_height = int((height + 2 * padding_h - dilation_h * (kernel_height - 1) - 1) / stride_h) + 1
out_width = int((width + 2 * padding_w - dilation_w * (kernel_width - 1) - 1) / stride_w) + 1

# Unfold x
patches = unfold(x, (kernel_height, kernel_width), dilation=(dilation_h, dilation_w), padding=(padding_h, padding_w), stride=(stride_h, stride_w))
print("\nUnfolded Input (patches):", patches.shape)
print(patches)

# Reshape patches for groups

patches_re = patches.reshape(batch_size, in_channels, kernel_height * kernel_width, out_height * out_width)
print("\nReshaped Patches for Groups:", patches_re.shape)
print(patches_re)

# Group splitting
patches_group_split = torch.stack(patches_re.chunk(chunks=groups, dim=1), dim=1)
print("\nGrouped Patches Split:", patches_group_split.shape)
print(patches_group_split)

# Weight splitting and reshaping
w_group_split = w.reshape(out_channels, in_channels * kernel_height * kernel_width)
print("\nReshaped Weights:", w_group_split.shape)
print(w_group_split)

# Convolution as matrix multiplication
convolved_patches = w_group_split @ patches_group_split.reshape(batch_size, groups, -1, out_height * out_width)
convolved_patches = convolved_patches.reshape(batch_size, out_channels, out_height, out_width)
print("\nConvolved Patches:", convolved_patches.shape)
print(convolved_patches)

# Add bias
convolved_patches += b.reshape(1, -1, 1, 1)
print("\nOutput with Bias Added:", convolved_patches.shape)
print(convolved_patches)




Input (x): torch.Size([1, 3, 5, 5])
tensor([[[[ 0.,  1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14.],
          [15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29.],
          [30., 31., 32., 33., 34.],
          [35., 36., 37., 38., 39.],
          [40., 41., 42., 43., 44.],
          [45., 46., 47., 48., 49.]],

         [[50., 51., 52., 53., 54.],
          [55., 56., 57., 58., 59.],
          [60., 61., 62., 63., 64.],
          [65., 66., 67., 68., 69.],
          [70., 71., 72., 73., 74.]]]])

Unfolded Input (patches): torch.Size([1, 27, 25])
tensor([[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  2.,  3.,  0.,  5.,  6.,  7.,
           8.,  0., 10., 11., 12., 13.,  0., 15., 16., 17., 18.],
         [ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,
           9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
         [ 0.,  0.,  0.,  0.,  0.,  1.,  2.,  3.,  4.,  0.,

In [45]:
patches.shape

torch.Size([1, 27, 25])

In [50]:
patches_re.shape

torch.Size([1, 3, 9, 25])

In [53]:
patches_group_split

tensor([[[[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  2.,  3.,  0.,  5.,  6.,
             7.,  8.,  0., 10., 11., 12., 13.,  0., 15., 16., 17., 18.],
           [ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,
             8.,  9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
           [ 0.,  0.,  0.,  0.,  0.,  1.,  2.,  3.,  4.,  0.,  6.,  7.,  8.,
             9.,  0., 11., 12., 13., 14.,  0., 16., 17., 18., 19.,  0.],
           [ 0.,  0.,  1.,  2.,  3.,  0.,  5.,  6.,  7.,  8.,  0., 10., 11.,
            12., 13.,  0., 15., 16., 17., 18.,  0., 20., 21., 22., 23.],
           [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
            13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.],
           [ 1.,  2.,  3.,  4.,  0.,  6.,  7.,  8.,  9.,  0., 11., 12., 13.,
            14.,  0., 16., 17., 18., 19.,  0., 21., 22., 23., 24.,  0.],
           [ 0.,  5.,  6.,  7.,  8.,  0., 10., 11., 12., 13.,  0., 15., 16.,
            17., 18.,  

In [None]:
def max_pool2d(x, kernel_size, padding=0, stride=1, dilation=1, ctx=None):
  """A differentiable convolution of 2d tensors.

  Note: Read this following documentation regarding the output's shape.
  https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d

  Backward call:
    backward_fn: max_pool2d_backward
    args: y, x, padding, stride, dilation

  Args:
    x (torch.Tensor): The input tensor. Has shape `(batch_size, in_channels, height, width)`.
    kernel_size (Tuple[int, int] or int): The kernel size in each dimension (height, width).
    padding (Tuple[int, int] or int, Optional): The padding in each dimension (height, width).
      Defaults to 0.
    stride (Tuple[int, int] or int, Optional): The stride in each dimension (height, width).
      Defaults to 1.
    dilation (Tuple[int, int] or int, Optional): The dilation in each dimension (height, width).
      Defaults to 1.
    ctx (List, optional): The autograd context. Defaults to None.

  Returns:
    y (torch.Tensor): The output tensor.
      Has shape `(batch_size, in_channels, out_height, out_width)`.
  """
  batch_size, in_channels, height, width = x.shape
  kernel_y, kernal_x = kernel_size
  pad_y, pad_x = _pair(padding) if isinstance(padding, int) else padding
  step_y, step_x = _pair(stride) if isinstance(stride, int) else stride
  dil_y, dil_x = _pair(dilation) if isinstance(dilation, int) else dilation

  out_height = int(1 + ((height + 2 * pad_y - dil_y * (kernel_y - 1) - 1) / step_y))
  out_width = int(1 + ((width + 2 * pad_x - dil_x * (kernal_x - 1) - 1) / step_x))

  # Unfold x to patches
  input_matrix = unfold(x, (kernel_y, kernal_x),
                    dilation=dilation,
                    padding=padding,
                    stride=stride)
  input_matrix = input_matrix.reshape(batch_size, in_channels, kernel_y * kernal_x, out_height, out_width)

  # Take max over each patch
  y, index = input_matrix.max(dim=2)

  # Save context for the backward pass
  if ctx is not None:
      ctx += [(max_pool2d_backward, [y, x, index, kernel_size, padding, stride, dilation])]

  return y


def max_pool2d_backward(y, x, index, kernel_size, padding, stride, dilation):
  """Backward computation of `max_pool2d`.

  Propagates the gradients of `y` (in `y.grad`) to `x` and accumulates it in `x.grad`.

  Args:
    y (torch.Tensor): The output tensor.
      Has shape `(batch_size, in_channels, out_height, out_width)`.
    x (torch.Tensor): The input tensor.
      Has shape `(batch_size, in_channels, height, width)`.
    index (torch.Tensor): Auxilary tensor with indices of the maximum elements. You are
      not restricted to a specific format.
    kernel_size (Tuple[int, int] or int): The kernel size in each dimension (height, width).
    padding (Tuple[int, int] or int, Optional): The padding in each dimension (height, width).
      Defaults to 0.
    stride (Tuple[int, int] or int, Optional): The stride in each dimension (height, width).
      Defaults to 1.
    dilation (Tuple[int, int] or int, Optional): The dilation in each dimension (height, width).
      Defaults to 1.
  """
  # Extract and parse input
  batch_size, in_channels, height, width = x.shape
  batch_size, in_channels, out_height, out_width = y.shape
  kernel_y, kernal_x = _pair(kernel_size) if isinstance(kernel_size, int) else kernel_size

  # Convert indices to one-hot encoding
  one_hot_entries = one_hot(index, num_classes=kernel_y * kernal_x)
  one_hot_entries = one_hot_entries.permute([0, 1, 4, 2, 3])

  # Compute the gradient w.r.t. unfolded tensor
  permuted_x_grad = one_hot_entries * y.grad.unsqueeze(2)
  permuted_x_grad = permuted_x_grad.reshape(batch_size, in_channels * kernel_y * kernal_x, out_height * out_width)

  # Add gradients to x.grad using fold
  x_grad = fold(permuted_x_grad,
                output_size=(height, width),
                kernel_size=(kernel_y, kernal_x),
                dilation=dilation,
                padding=padding,
                stride=stride)
  
  # Accumulate the gradients
  x.grad += x_grad

In [58]:
def max_pool2d_v2(x, kernel_size, padding=0, stride=1, dilation=1, ctx=None):
  """A differentiable convolution of 2d tensors.

  Note: Read this following documentation regarding the output's shape.
  https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d

  Backward call:
    backward_fn: max_pool2d_backward
    args: y, x, padding, stride, dilation

  Args:
    x (torch.Tensor): The input tensor. Has shape `(batch_size, in_channels, height, width)`.
    kernel_size (Tuple[int, int] or int): The kernel size in each dimension (height, width).
    padding (Tuple[int, int] or int, Optional): The padding in each dimension (height, width).
      Defaults to 0.
    stride (Tuple[int, int] or int, Optional): The stride in each dimension (height, width).
      Defaults to 1.
    dilation (Tuple[int, int] or int, Optional): The dilation in each dimension (height, width).
      Defaults to 1.
    ctx (List, optional): The autograd context. Defaults to None.

  Returns:
    y (torch.Tensor): The output tensor.
      Has shape `(batch_size, in_channels, out_height, out_width)`.
  """
  # BEGIN SOLUTION
  # extract and parse input
  batch_size, in_channels, height, width = x.shape
  kernel_height, kernel_width = _pair(kernel_size) if type(kernel_size) is int else kernel_size
  padding_h, padding_w = _pair(padding) if type(padding) is int else padding
  stride_h, stride_w = _pair(stride) if type(stride) is int else stride
  dilation_h, dilation_w = _pair(dilation) if type(dilation) is int else dilation
  out_height = int(1 + ((height + 2*padding_h - dilation_h * (kernel_height - 1) - 1) / (stride_h)))
  out_width = int(1 + ((width + 2*padding_w - dilation_w * (kernel_width - 1) - 1) / (stride_w)))

  # unfold x to patches
  # patches.shape - (batch_size, C*Kh*Kw, out_height*out_width)
  patches = unfold(x, (kernel_height, kernel_width),
                    dilation=dilation,
                    padding=padding, 
                    stride=stride)
  patches = patches.reshape(batch_size, in_channels, kernel_height*kernel_width, out_height, out_width)

  # take max over each patch
  y, index = patches.max(dim=2)

  if ctx is not None:
    ctx += [(max_pool2d_backward, [y, x, index, kernel_size, padding, stride, dilation])]

  return y

In [59]:
# test both implementations:
x = torch.tensor([[[[1., 2., 3., 4.],
                    [5., 6., 7., 8.],
                    [9., 10., 11., 12.],
                    [13., 14., 15., 16.]],
                   [[17., 18., 19., 20.],
                    [21., 22., 23., 24.],
                    [25., 26., 27., 28.],
                    [29., 30., 31., 32.]],
                   [[33., 34., 35., 36.],
                    [37., 38., 39., 40.],
                    [41., 42., 43., 44.],
                    [45., 46., 47., 48.]]]])

kernel_size = 2
padding = 0

stride = 2
dilation = 1

y1 = max_pool2d_v1(x, kernel_size, padding, stride, dilation)
y2 = max_pool2d_v2(x, kernel_size, padding, stride, dilation)

print(y1)
print(y2)


RuntimeError: shape '[1, 3, 4, 4]' is invalid for input of size 12