In [141]:
import torch
import torch.nn as nn

In [142]:
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


In [143]:

class WindowPartition2(nn.Module):
    """
    Utility module for partitioning and reversing windows in a patch grid.

    Input shape: (B, H, W, *embed_dims)
    After partitioning with a given window_size, the tensor is reshaped into:
        (B, H//window_size, W//window_size, window_size, window_size, *embed_dims)
    """
    def __init__(self, window_size: int):
        super(WindowPartition2, self).__init__()
        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Partition the input tensor into non-overlapping windows.

        Args:
            x (torch.Tensor): Input tensor of shape (B, H, W, *embed_dims).

        Returns:
            torch.Tensor: Partitioned tensor with shape 
                (B, H//window_size, W//window_size, window_size, window_size, *embed_dims).
        """
        B, H, W, *embed_dims = x.shape
        ws = self.window_size
        if H % ws != 0 or W % ws != 0:
            raise ValueError(f"H and W must be divisible by window_size {ws}. Got H={H}, W={W}.")
        # Reshape to split H and W into windows.
        x = x.view(B, H // ws, ws, W // ws, ws, *embed_dims)
        # Permute to group the window blocks together.
        windows = x.permute(0, 1, 3, 2, 4, *range(5, x.dim()))
        return windows

    def reverse(self, windows: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Reverse the window partition to reconstruct the original tensor.

        Args:
            windows (torch.Tensor): Partitioned windows with shape 
                (B, H//window_size, W//window_size, window_size, window_size, *embed_dims).
            H (int): Original height.
            W (int): Original width.

        Returns:
            torch.Tensor: Reconstructed tensor of shape (B, H, W, *embed_dims).
        """
        ws = self.window_size
        B, num_h, num_w, ws1, ws2, *embed_dims = windows.shape
        # Permute back to interleave the window dimensions.
        x = windows.permute(0, 1, 3, 2, 4, *range(5, windows.dim())).contiguous()
        # Reshape to reconstruct the original feature map.
        x = x.view(B, num_h * ws1, num_w * ws2, *embed_dims)
        return x


In [144]:

B, H, W, C = 2, 56,56, 48  # Batch size = 2, height = 16, width = 16, channels = 3
x = torch.randn(B, H, W, C)  # Random tensor for example
x_tensorized = x.view(B, H, W, 4, 4, 3)

# Define the window size
window_size =7  # You can change this value depending on your model's configuration

# Step 1: Partition the input tensor into windows
windows = window_partition(x, window_size)
print("Shape of windows after partitioning:", windows.shape)

# Step 3: Reverse the windows back to the original shape
reconstructed_x = window_reverse(windows, window_size, H, W)
print("Shape after reversing the windows:", reconstructed_x.shape)

# Ensure the reconstructed tensor matches the original shape
print(reconstructed_x.shape == x.shape)

Shape of windows after partitioning: torch.Size([128, 7, 7, 48])
Shape after reversing the windows: torch.Size([2, 56, 56, 48])
True


In [145]:
# Initialize the window partition module
partitioner = WindowPartition2(window_size)

# Partition the windows
windows_tensorized = partitioner(x_tensorized)

print("Partitioned shape:", windows_tensorized.shape)
# Expected: (B, H//ws, W//ws, ws, ws, 4, 4, 3)

# Reverse it to get back the original shape
# x_reconstructed = partitioner.reverse(windows_tensorized, H, W)

# print("Reconstructed shape:", x_reconstructed.shape)
# Expected: (B, H, W, 4, 4, 3)

# Check if reconstruction is correct
# print("Reconstruction successful:", torch.allclose(x, x_reconstructed))

Partitioned shape: torch.Size([2, 8, 8, 7, 7, 4, 4, 3])


In [146]:
test_tensorized_windows = windows_tensorized[0,0,0,:,:,:,:,:].reshape(7,7,-1)
test_windows = windows[:,:,:,:]   

print(test_tensorized_windows.shape)

print(test_windows.shape)

print(torch.equal(test_tensorized_windows , test_windows))

torch.Size([7, 7, 48])
torch.Size([128, 7, 7, 48])
False
