In [7]:
import torch

In [25]:
def unfold(input: torch.Tensor,
           window_size: tuple): #torch.Tensor
    """
    Unfolds (non-overlapping) a given feature map by the given window size (stride = window size)
    :param input: (torch.Tensor) Input feature map of the shape [batch size, channels, height, width]
    :param window_size: (int) Window size to be applied
    :return: (torch.Tensor) Unfolded tensor of the shape [batch size * windows, channels, window size, window size]
    """
    # Get original shape
    _, channels, height, width = input.shape  
    # type: int, int, int, int
    # Unfold input
    output: torch.Tensor = input.unfold(dimension=3, size=window_size[1], step=window_size[1]) \
        .unfold(dimension=2, size=window_size[0], step=window_size[0])
    # Reshape to [batch size * windows, channels, window size, window size]
    output: torch.Tensor = output.permute(0, 2, 3, 1, 5, 4).reshape(-1, channels, window_size[0], window_size[1])
    # (B* windows, channel, H_winsize, W_winsize)
    return output

In [27]:
dummy=torch.randn(4, 3, 192, 640)
print(dummy[0][0])
win_size=(6,20)
result=unfold(dummy, win_size)
print(result.shape)

tensor([[ 0.2551,  1.9334,  0.3962,  ...,  0.7247,  0.5572,  0.7516],
        [-0.2861, -2.3457,  0.6616,  ..., -2.9714, -1.3034, -0.2616],
        [ 0.6227, -0.3296, -0.5444,  ..., -2.6347,  0.3530,  0.4950],
        ...,
        [-0.7737,  0.7538,  0.7473,  ..., -0.9864, -2.0080, -0.4005],
        [-0.4337,  2.3665,  0.1365,  ..., -0.0404,  0.4469, -1.0898],
        [-2.5206, -0.7545,  1.9392,  ..., -0.2799, -0.2950,  0.7595]])
torch.Size([4096, 3, 6, 20])


In [29]:
def fold(input: torch.Tensor,
         window_size: tuple,
         height: int,
         width: int):# torch.Tensor
    """
    Fold a tensor of windows again to a 4D feature map
    :param input: (torch.Tensor) Input tensor of windows [batch size * windows, channels, h_window size, w_window size]
    :param window_size: (int) Window size to be reversed
    :param height: (int) Height of the feature map
    :param width: (int) Width of the feature map
    :return: (torch.Tensor) Folded output tensor of the shape [batch size, channels, height, width]
    """
    # Get channels of windows
    channels: int = input.shape[1]
    # Get original batch size
    batch_size: int = int(input.shape[0] // (height * width // window_size[0] // window_size[1]))
    # Reshape input to
    output: torch.Tensor = input.view(batch_size, height // window_size[0], width // window_size[1], channels,
                                      window_size[0], window_size[1])
    output: torch.Tensor = output.permute(0, 3, 1, 4, 2, 5).reshape(batch_size, channels, height, width)
    return output

In [32]:
recover=fold(result, (6, 20), 192, 640)
print(recover.shape)
print(recover[0][0])

torch.Size([4, 3, 192, 640])
tensor([[ 0.2551,  1.9334,  0.3962,  ...,  0.7247,  0.5572,  0.7516],
        [-0.2861, -2.3457,  0.6616,  ..., -2.9714, -1.3034, -0.2616],
        [ 0.6227, -0.3296, -0.5444,  ..., -2.6347,  0.3530,  0.4950],
        ...,
        [-0.7737,  0.7538,  0.7473,  ..., -0.9864, -2.0080, -0.4005],
        [-0.4337,  2.3665,  0.1365,  ..., -0.0404,  0.4469, -1.0898],
        [-2.5206, -0.7545,  1.9392,  ..., -0.2799, -0.2950,  0.7595]])


In [38]:
x=(6,20)
y=(x[0]//2, x[1]//2)
y

(3, 10)

In [1]:
for i in range(4):
    print(i)

0
1
2
3
