In [39]:
import torch
import torch.nn as nn

In [40]:
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


In [41]:
import torch
import torch.nn as nn
import sys
sys.path.append("..")
from Tensorized_components.Window_partition import WindowPartition

class ShiftedWindowPartition(nn.Module):
    """
    Utility module for partitioning and reversing windows with a spatial shift applied.
    
    This class performs a spatial shift before partitioning the input tensor into windows.
    After partitioning and reverse operations, the spatial shift is compensated for.
    
    Input shape: (B, H, W, *embed_dims)
    """
    def __init__(self, window_size: int, shift_size: int):
        super(ShiftedWindowPartition, self).__init__()
        self.window_size = window_size
        self.shift_size = shift_size
        self.window_partition = WindowPartition(window_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shift the input tensor and partition it into windows.

        Args:
            x (torch.Tensor): Input tensor of shape (B, H, W, *embed_dims).

        Returns:
            torch.Tensor: Partitioned windows after spatial shift.
        """
        # Apply negative shift to the spatial dimensions.
        x_shifted = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        # Partition the shifted tensor using WindowPartition.
        windows = self.window_partition(x_shifted)
        return windows

    def reverse(self, windows: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Reverse the partition and then reverse the spatial shift to reconstruct the original tensor.

        Args:
            windows (torch.Tensor): Partitioned windows with shape 
                (B, H//window_size, W//window_size, window_size, window_size, *embed_dims).
            H (int): Original height.
            W (int): Original width.

        Returns:
            torch.Tensor: Reconstructed tensor after compensating for the spatial shift.
        """
        # Reverse partition to get back the shifted tensor.
        x_reconstructed = self.window_partition.reverse(windows, H, W)
        # Roll back the tensor to reverse the spatial shift.
        x_final = torch.roll(x_reconstructed, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        return x_final


In [56]:

B, H, W, C = 1, 56,56, 48  # Batch size = 2, height = 16, width = 16, channels = 3
x = torch.randn(B, H, W, C)  # Random tensor for example
x_tensorized = x.view(B, H, W, 4, 4, 3)

# Define the window size
window_size =7  # You can change this value depending on your model's configuration

shift_size = 3

# Step 1: Partition the input tensor into windows

shifted_x = torch.roll(x, shifts=(-3, -3), dims=(1, 2))

windows = window_partition(shifted_x, window_size)
print("Shape of windows after partitioning:", windows.shape)

# Step 3: Reverse the windows back to the original shape
reconstructed_x = window_reverse(windows, window_size, H, W)
x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))

print("Shape after reversing the windows:", reconstructed_x.shape)

# Ensure the reconstructed tensor matches the original shape
print(reconstructed_x.shape == x.shape)

Shape of windows after partitioning: torch.Size([64, 7, 7, 48])
Shape after reversing the windows: torch.Size([1, 56, 56, 48])
True


In [58]:
# Initialize the window partition module
partitioner = ShiftedWindowPartition(window_size , shift_size)

# Partition the windows
windows_tensorized = partitioner(x_tensorized)

reverse  =partitioner.reverse(windows_tensorized, H, W)

print("Partitioned shape:", windows_tensorized.shape)

print("Reconstructed shape:", reverse.shape)

print(torch.eq(reverse ,x_tensorized))

Partitioned shape: torch.Size([1, 8, 8, 7, 7, 4, 4, 3])
Reconstructed shape: torch.Size([1, 56, 56, 4, 4, 3])
tensor([[[[[[True, True, True],
            [True, True, True],
            [True, True, True],
            [True, True, True]],

           [[True, True, True],
            [True, True, True],
            [True, True, True],
            [True, True, True]],

           [[True, True, True],
            [True, True, True],
            [True, True, True],
            [True, True, True]],

           [[True, True, True],
            [True, True, True],
            [True, True, True],
            [True, True, True]]],


          [[[True, True, True],
            [True, True, True],
            [True, True, True],
            [True, True, True]],

           [[True, True, True],
            [True, True, True],
            [True, True, True],
            [True, True, True]],

           [[True, True, True],
            [True, True, True],
            [True, True, True],
            

In [55]:
test_tensorized_windows = windows_tensorized[0,1,3,:,:,:,:,:].reshape(7,7,-1)
test_windows = windows[11,:,:,:]   

print(test_tensorized_windows.shape)

print(test_windows.shape)

print(torch.equal(test_tensorized_windows , test_windows))

torch.Size([7, 7, 48])
torch.Size([7, 7, 48])
True
