<a href="https://colab.research.google.com/github/aminehd/Alice-Differentiable-Adventures/blob/main/Chapter_7/Conv2D_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch's Unfold Operation

## What Is the Unfold Operation?

PyTorch's `unfold` operation extracts sliding k×k patches from an input tensor and flattens each patch into a column. It is the first step when a convolution is computed as a matrix multiplication.
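To make the link to convolution concrete, here is a minimal sketch that rebuilds `F.conv2d` from `F.unfold` plus a matrix multiplication. The tensor sizes (an 8×8 input with 3 channels, 5 filters of size 3×3) and the names `x`, `weight`, and `cols` are assumptions chosen only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8)        # [batch, channels, height, width]
weight = torch.randn(5, 3, 3, 3)   # [out_channels, in_channels, k, k]

# Step 1: unfold into columns of flattened patches -> [1, 3*3*3, 6*6]
cols = F.unfold(x, kernel_size=3)

# Step 2: flatten each filter into a row and multiply -> [1, 5, 36]
out = weight.view(5, -1) @ cols

# Step 3: reshape the patch axis back into the 6x6 output grid
out = out.view(1, 5, 6, 6)

# Compare with the built-in convolution (no bias, stride 1, no padding)
matches = torch.allclose(out, F.conv2d(x, weight), atol=1e-4)
print(matches)
```

The rest of this notebook looks only at step 1, the unfold itself.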

## Example 1: Single-Channel Unfold

```
Input Tensor (4×4):
┌─────┬─────┬─────┬─────┐
│ 1.0 │ 2.0 │ 3.0 │ 4.0 │
├─────┼─────┼─────┼─────┤
│ 5.0 │ 6.0 │ 7.0 │ 8.0 │
├─────┼─────┼─────┼─────┤
│ 9.0 │ 10.0│ 11.0│ 12.0│
├─────┼─────┼─────┼─────┤
│ 13.0│ 14.0│ 15.0│ 16.0│
└─────┴─────┴─────┴─────┘
```
Let's apply `unfold` with a 2×2 kernel and stride=1:

In [2]:
import torch
# 4×4 single-channel image with float32 values, matching the diagram above
img = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=torch.float32)

`unfold` expects a batched input of shape `[batch, channels, height, width]`, so we call `unsqueeze(0)` twice to add a channel dimension and then a batch dimension:


In [5]:
print(img.shape)
print(img.unsqueeze(0).shape)
print(img.unsqueeze(0).unsqueeze(0).shape)

torch.Size([4, 4])
torch.Size([1, 4, 4])
torch.Size([1, 1, 4, 4])


In [6]:
unfloded = torch.nn.functional.unfold(
    img.unsqueeze(0).unsqueeze(0),
    kernel_size=2,
    stride=1
)
print(unfloded.shape)

torch.Size([1, 4, 9])


### The Unfold Process

Unfold slides a 2×2 window across the input, extracting 9 patches (3×3 output positions):

```
Patch 0:             Patch 1:             Patch 2:
┌─────┬─────┐        ┌─────┬─────┐        ┌─────┬─────┐
│ 1.0 │ 2.0 │        │ 2.0 │ 3.0 │        │ 3.0 │ 4.0 │
├─────┼─────┤        ├─────┼─────┤        ├─────┼─────┤
│ 5.0 │ 6.0 │        │ 6.0 │ 7.0 │        │ 7.0 │ 8.0 │
└─────┴─────┘        └─────┴─────┘        └─────┴─────┘

Patch 3:             Patch 4:             Patch 5:
┌─────┬─────┐        ┌─────┬─────┐        ┌─────┬─────┐
│ 5.0 │ 6.0 │        │ 6.0 │ 7.0 │        │ 7.0 │ 8.0 │
├─────┼─────┤        ├─────┼─────┤        ├─────┼─────┤
│ 9.0 │ 10.0│        │ 10.0│ 11.0│        │ 11.0│ 12.0│
└─────┴─────┘        └─────┴─────┘        └─────┴─────┘

Patch 6:             Patch 7:             Patch 8:
┌─────┬─────┐        ┌─────┬─────┐        ┌─────┬─────┐
│ 9.0 │ 10.0│        │ 10.0│ 11.0│        │ 11.0│ 12.0│
├─────┼─────┤        ├─────┼─────┤        ├─────┼─────┤
│ 13.0│ 14.0│        │ 14.0│ 15.0│        │ 15.0│ 16.0│
└─────┴─────┘        └─────┴─────┘        └─────┴─────┘
```
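The sliding window pictured above can be sketched with two plain Python loops. This is only an illustration of the patch order (row by row, left to right), not how PyTorch implements it; the names `patches`, `top`, and `left` are made up for this sketch.

```python
img = [
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0],
    [9.0, 10.0, 11.0, 12.0],
    [13.0, 14.0, 15.0, 16.0],
]
k, stride = 2, 1

patches = []
# Move the window's top-left corner over every position where it still fits
for top in range(0, len(img) - k + 1, stride):
    for left in range(0, len(img[0]) - k + 1, stride):
        patch = [row[left:left + k] for row in img[top:top + k]]
        patches.append(patch)

print(len(patches))  # 9
print(patches[4])    # [[6.0, 7.0], [10.0, 11.0]]
```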

### The Resulting Unfolded Tensor

The `unfold` operation converts these patches to columns in a new tensor:

```
Unfolded Tensor (4×9):
┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
│ 1.0 │ 2.0 │ 3.0 │ 5.0 │ 6.0 │ 7.0 │ 9.0 │ 10.0│ 11.0│ ← Top-left of each patch
├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
│ 2.0 │ 3.0 │ 4.0 │ 6.0 │ 7.0 │ 8.0 │ 10.0│ 11.0│ 12.0│ ← Top-right of each patch
├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
│ 5.0 │ 6.0 │ 7.0 │ 9.0 │ 10.0│ 11.0│ 13.0│ 14.0│ 15.0│ ← Bottom-left of each patch
├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
│ 6.0 │ 7.0 │ 8.0 │ 10.0│ 11.0│ 12.0│ 14.0│ 15.0│ 16.0│ ← Bottom-right of each patch
└─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
  P0    P1    P2    P3    P4    P5    P6    P7    P8   ← Patch number
```

Each column is one 2×2 patch from the input, flattened into a vector. This explains why the result has shape `[1, 4, 9]`:
- 1 = batch size
- 4 = values per column, since each column is a flattened 2×2 patch
- 9 = number of patches: the top-left cell of a patch can be any input pixel except those in the last row and last column, giving 3 × 3 = 9 positions
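As a quick check of the column layout, the sketch below rebuilds the same 4×4 input and reads one row and one column of the unfolded tensor. The variable names `row0` and `patch4` are illustrative only.

```python
import torch
import torch.nn.functional as F

img = torch.arange(1.0, 17.0).reshape(4, 4)                    # values 1..16
unfolded = F.unfold(img[None, None], kernel_size=2, stride=1)  # [1, 4, 9]

# Row 0 holds the top-left value of every patch
row0 = unfolded[0, 0]
print(row0)    # tensor([ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.])

# Column 4 reshaped to 2x2 recovers Patch 4
patch4 = unfolded[0, :, 4].reshape(2, 2)
print(patch4)  # tensor([[ 6.,  7.], [10., 11.]])
```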

In [8]:
unfloded.shape

torch.Size([1, 4, 9])
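The last dimension of that shape can be predicted without running `unfold`. The sketch below uses the window-count formula given in the `torch.nn.Unfold` documentation; `patches_per_axis` is a hypothetical helper name for this notebook, not a PyTorch function.

```python
def patches_per_axis(size, kernel, stride=1, padding=0, dilation=1):
    """Number of window positions along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

side = patches_per_axis(4, 2)  # 4×4 input, 2×2 kernel, stride 1
print(side, side * side)       # 3 9
```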

## Example 2: Multi-Channel Unfold

The same idea extends to multi-channel input. This example shows why unfolding a 3-channel image with a 3×3 kernel produces 27 rows.

Consider a 3-channel input tensor with shape `[1, 3, 32, 32]`:

```
3-Channel Input:
Channel 0:                  Channel 1:                  Channel 2:
┌────┬────┬─────┬─   ─┬────┐ ┌────┬────┬─────┬─   ─┬────┐ ┌────┬────┬─────┬─   ─┬────┐
│ A0 │ B0 │ C0  │ ... │ Z0 │ │ A1 │ B1 │ C1  │ ... │ Z1 │ │ A2 │ B2 │ C2  │ ... │ Z2 │
├────┼────┼─────┼─   ─┼────┤ ├────┼────┼─────┼─   ─┼────┤ ├────┼────┼─────┼─   ─┼────┤
│ D0 │ E0 │ F0  │ ... │... │ │ D1 │ E1 │ F1  │ ... │... │ │ D2 │ E2 │ F2  │ ... │... │
├────┼────┼─────┼─   ─┼────┤ ├────┼────┼─────┼─   ─┼────┤ ├────┼────┼─────┼─   ─┼────┤
│ G0 │ H0 │ I0  │ ... │... │ │ G1 │ H1 │ I1  │ ... │... │ │ G2 │ H2 │ I2  │ ... │... │
├────┼────┼─────┼─   ─┼────┤ ├────┼────┼─────┼─   ─┼────┤ ├────┼────┼─────┼─   ─┼────┤
│... │... │ ... │ ... │... │ │... │... │ ... │ ... │... │ │... │... │ ... │ ... │... │
├────┼────┼─────┼─   ─┼────┤ ├────┼────┼─────┼─   ─┼────┤ ├────┼────┼─────┼─   ─┼────┤
│... │... │ ... │ ... │... │ │... │... │ ... │ ... │... │ │... │... │ ... │ ... │... │
└────┴────┴─────┴─   ─┴────┘ └────┴────┴─────┴─   ─┴────┘ └────┴────┴─────┴─   ─┴────┘
```

When we apply `unfold` with a 3×3 kernel, we extract a patch at every position where the 3×3 window fits inside the input:


In [9]:

img_3channel = torch.rand(1, 3, 32, 32)

unfloded = torch.nn.functional.unfold(
    img_3channel,
    kernel_size=3,
    stride=1
)
print(unfloded.shape)


torch.Size([1, 27, 900])


Let's look at the first patch (top-left 3×3 window) across all channels:

```
First Patch (3×3 window at position [0,0]):

Channel 0:         Channel 1:         Channel 2:
┌────┬────┬────┐   ┌────┬────┬────┐   ┌────┬────┬────┐
│ A0 │ B0 │ C0 │   │ A1 │ B1 │ C1 │   │ A2 │ B2 │ C2 │
├────┼────┼────┤   ├────┼────┼────┤   ├────┼────┼────┤
│ D0 │ E0 │ F0 │   │ D1 │ E1 │ F1 │   │ D2 │ E2 │ F2 │
├────┼────┼────┤   ├────┼────┼────┤   ├────┼────┼────┤
│ G0 │ H0 │ I0 │   │ G1 │ H1 │ I1 │   │ G2 │ H2 │ I2 │
└────┴────┴────┘   └────┴────┴────┘   └────┴────┴────┘
```

### Converting to the Unfolded Representation

When unfolded, each 3×3 patch from each channel gets flattened and concatenated to form a single column:

```
First Patch Unfolded (27 values):
┌────┐
│ A0 │
├────┤
│ B0 │
├────┤  Channel 0
│ C0 │  (9 values)
├────┤
│ D0 │
├────┤
│... │
├────┤
│ I0 │
├────┤
│ A1 │
├────┤
│ B1 │
├────┤  Channel 1
│ C1 │  (9 values)
├────┤
│ D1 │
├────┤
│... │
├────┤
│ I1 │
├────┤
│ A2 │
├────┤
│ B2 │
├────┤  Channel 2
│ C2 │  (9 values)
├────┤
│ D2 │
├────┤
│... │
├────┤
│ I2 │
└────┘
```
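The channel-by-channel ordering shown above can be checked directly. This is a minimal sketch on a random input; the names `first_patch`, `whole_column_matches`, and `channel1_matches` are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
img_3channel = torch.rand(1, 3, 32, 32)
unfolded = F.unfold(img_3channel, kernel_size=3, stride=1)

# Top-left 3×3 window across all channels: [channel, row, col]
first_patch = img_3channel[0, :, 0:3, 0:3]

# Column 0 is that window flattened channel by channel
whole_column_matches = torch.equal(unfolded[0, :, 0], first_patch.reshape(-1))

# Rows 9..17 of the column hold channel 1's nine values (A1 .. I1)
channel1_matches = torch.equal(
    unfolded[0, 9:18, 0], img_3channel[0, 1, 0:3, 0:3].reshape(-1)
)
print(whole_column_matches, channel1_matches)
```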



The unfold operation produces a tensor of shape `[1, 27, 900]`, where
- 1 = batch size
- 27 = channels × kernel height × kernel width = 3 × 3 × 3, because we flatten a 3×3 window across all 3 channels
- 900 = number of patches = 30 × 30, because without padding a 3×3 window fits in 32 − 3 + 1 = 30 positions along each axis
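The patch count depends on padding. As a minimal sketch on the same `[1, 3, 32, 32]` input: without padding the 3×3 window fits in 30 × 30 = 900 positions, while passing `padding=1` to `F.unfold` zero-pads the border so that every one of the 32 × 32 = 1024 pixels gets a window centered on it.

```python
import torch
import torch.nn.functional as F

img_3channel = torch.rand(1, 3, 32, 32)

no_pad = F.unfold(img_3channel, kernel_size=3, stride=1)
with_pad = F.unfold(img_3channel, kernel_size=3, stride=1, padding=1)

print(no_pad.shape)    # torch.Size([1, 27, 900])
print(with_pad.shape)  # torch.Size([1, 27, 1024])
```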