In [1]:
import torch

def _window_partition(x, window_size):
    """Split ``x`` into non-overlapping ``window_size x window_size`` windows.

    Args:
        x: Tensor of shape ``[B, H, W, C]``. ``H`` and ``W`` must be multiples
            of ``window_size``; otherwise ``Tensor.view`` raises.
        window_size: Side length of each square window.

    Returns:
        Tensor of shape ``[B * (H // window_size) * (W // window_size),
        window_size, window_size, C]``. Windows are ordered batch first, then
        window-row, then window-column.
    """
    batch, height, width, channels = x.shape
    x = x.view(
        batch,
        height // window_size, window_size,
        width // window_size, window_size,
        channels,
    )
    # Move the two window-grid axes next to each other so that every window's
    # pixels become contiguous before the final flatten.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(-1, window_size, window_size, channels)


def generate_2d_attention_mask(H=8, W=8, window_size=4, shift_size=2, *,
                               mask_value=-100.0):
    """Build a shifted-window 2D attention mask.

    The ``H x W`` grid is split into 3 x 3 labelled regions using the slices
    ``[0, -window_size)``, ``[-window_size, -shift_size)`` and
    ``[-shift_size, end)`` along each axis. The grid is then partitioned into
    non-overlapping ``window_size x window_size`` windows. Within each window,
    a pair of tokens gets ``0.0`` when both tokens share a region label and
    ``mask_value`` otherwise. Presumably the result is added to attention
    logits before a softmax; confirm with the caller.

    With ``shift_size=0`` the last slice covers the whole grid, so every token
    ends up with the same label and the mask is all zeros.

    Args:
        H: Grid height; expected to be a multiple of ``window_size``.
        W: Grid width; expected to be a multiple of ``window_size``.
        window_size: Side length of each attention window.
        shift_size: Width of the shifted border band.
        mask_value: Value written for cross-region token pairs. The default
            is -100.0.

    Returns:
        Float tensor of shape
        ``[1, H // window_size, W // window_size, 1, 1, 1,
        window_size*window_size, window_size*window_size]``.
        For the defaults this is ``[1, 2, 2, 1, 1, 1, 16, 16]``.
    """
    # ------------------------------------------------------------------
    # 1) Label every pixel with the index of the region it belongs to.
    # ------------------------------------------------------------------
    img_mask = torch.zeros((1, H, W, 1))
    region_slices = (
        slice(0, -window_size),
        slice(-window_size, -shift_size),
        slice(-shift_size, None),
    )
    label = 0
    for rows in region_slices:
        for cols in region_slices:
            img_mask[:, rows, cols, :] = label
            label += 1

    # ------------------------------------------------------------------
    # 2) Partition into windows and flatten each window's labels:
    #    [num_windows, window_size * window_size]
    # ------------------------------------------------------------------
    area = window_size * window_size
    mask_windows = _window_partition(img_mask, window_size).view(-1, area)

    # ------------------------------------------------------------------
    # 3) Pairwise label comparison. The difference is exactly 0 for
    #    same-region pairs, so only the non-zero entries need filling.
    # ------------------------------------------------------------------
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(mask_value))

    # ------------------------------------------------------------------
    # 4) Reshape [num_windows, area, area] into the final layout
    #    [1, num_win_h, num_win_w, 1, 1, 1, area, area].
    #    num_windows is ordered row-major over the window grid.
    # ------------------------------------------------------------------
    num_win_h = H // window_size
    num_win_w = W // window_size
    return attn_mask.reshape(1, num_win_h, num_win_w, 1, 1, 1, area, area)


if __name__ == "__main__":
    # Demo: an 8x8 grid with 4x4 windows shifted by 2 gives a 2x2 grid of
    # windows, each holding a 16x16 mask.
    demo_kwargs = dict(H=8, W=8, window_size=4, shift_size=2)
    mask = generate_2d_attention_mask(**demo_kwargs)
    print("Final attention mask shape:", mask.shape)



Final attention mask shape: torch.Size([1, 2, 2, 1, 1, 1, 16, 16])
