In [1]:
import torch
from einops.layers.torch import Rearrange
from einops import repeat, rearrange
import torch.nn as nn
import cv2
import matplotlib.pyplot as plt
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def read_image(img_src="/home/users/akshay.v/nano_models/images/guitar.jpg", img_size = (224,224)):
    image = cv2.imread(img_src)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, img_size, interpolation = cv2.INTER_CUBIC)
    image = image.astype('float32') / 255.0  # Normalize to [0, 1]
    image = torch.Tensor(image)
    image = image.permute(2,0,1) # C , H ,W
    image = image.unsqueeze(0) #to add the batch dimension
    print("image shape is " , image.shape)
    return image

In [3]:
class CPB_MLP(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2, 512, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_heads, bias=False)
        )
        
    def forward(self, x):
        return self.layers(x)

In [18]:
window_size = [3, 3]  # Example window size
relative_coords = torch.stack(torch.meshgrid(
    torch.arange(-(window_size[0]-1), window_size[0]),
    torch.arange(-(window_size[1]-1), window_size[1])
)).permute(1, 2, 0).unsqueeze(0)


In [19]:
torch.arange(-(window_size[0]-1), window_size[0])

tensor([-2, -1,  0,  1,  2])

In [13]:
relative_coords.shape

torch.Size([1, 5, 5, 2])

In [20]:
cpb_mlp = CPB_MLP(num_heads=8)
relative_position_bias = cpb_mlp(relative_coords.view(-1, 2).float()).view(
    window_size[0]*2-1, window_size[1]*2-1, -1)

In [22]:
relative_position_bias.shape

torch.Size([5, 5, 8])

### References

1) https://bolster.ai/blog/swin-transformers
2) https://anyline.com/news/transformers-in-computer-vision
3) https://chautuankien.medium.com/explanation-swin-transformer-93e7a3140877
4) https://www.youtube.com/watch?v=LxPDpAiyqSU&list=PL9iXGo3xD8jokWaLB8ZHUkjjv5Y_vPQnZ&index=2&ab_channel=AIOpenCourseWare
5) https://medium.com/thedeephub/building-swin-transformer-from-scratch-using-pytorch-hierarchical-vision-transformer-using-shifted-91cbf6abc678
6) https://towardsdatascience.com/a-comprehensive-guide-to-swin-transformer-64965f89d14c
7) https://www.youtube.com/watch?v=SndHALawoag&ab_channel=AICoffeeBreakwithLetitia

The code is inspired from
1) https://github.com/berniwal/swin-transformer-pytorch
2) https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py

## Problems with VIT

if we take each each patch as 1*1 and then calculate the number of tokens , we get around 63000 tokens and for larger images of size 1024 it goes to over a million tokens

In [3]:
class Patchify(nn.Module):
    def __init__(self, patch_size=56, stride_size=56):
        super().__init__()
        self.p = patch_size
        self.unfold = torch.nn.Unfold(kernel_size=patch_size, stride=stride_size)

    def forward(self, x):
        # x -> B c h w
        bs, c, h, w = x.shape
        
        x = self.unfold(x)
        # x -> B (c*p*p) L
        print(x.shape)
        # Reshaping into the shape we want
        a = x.view(bs, c, self.p, self.p, -1).permute(0, 4, 1, 2, 3)
        # a -> ( B no.of patches c p p )
        return a

In [4]:
patch = Patchify(patch_size=1, stride_size=1)
img_src = "/home/users/akshay.v/nano_models/images/guitar.jpg"
image = cv2.imread(img_src)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (256, 256), interpolation = cv2.INTER_CUBIC)
image = image.astype('float32') / 255.0  # Normalize to [0, 1]
image = torch.Tensor(image)
image = image.permute(2,0,1) # C , H ,W
image = image.unsqueeze(0) #to add the batch dimension
print(image.shape)
p = patch(image)
print(p.shape)
p = p.squeeze() #to remove the batch dimension for plotting
print(p.reshape(1,p.shape[0],-1).shape)

torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 65536])
torch.Size([1, 65536, 3, 1, 1])
torch.Size([1, 65536, 3])


# Swin Transformers
Will be going through the swin easy code 
The swin offical code is very simiar to this

## Patch Merging

Here we see different ways of how patch merging is implemented , u might notice for ex (1, 96, 192, 284) in the code , these show in the changes in size as it passed through the layers

There are 3 ways its implemented
1) using unfold + linear layer
2) using a conv layer
3) using a patch embed and patch merge , in the case of the offical swin repo

I prefer the conv layer one

In [95]:
# using unfold + linear
class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0) 
        # will just conver it from 1, 3, 224 ,224 into 1, (4*4*3), (56*56)
        # this will get resized to 1, 48, 56, 56 
        # if we take the above one and reshape it to 1, 56, 56 ,4 , 4, 3 we can print the entire image in patch of 4
        # this will get resized to 1, 56, 56, 48
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x):
        b, c, h, w = x.shape
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
        
        x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
        print(x.shape)
        x = self.linear(x) # [1, 56, 56, 96]
        return x
    
# using conv
class PatchMerging_Conv(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Conv2d(in_channels=in_channels, 
                                     out_channels=out_channels, 
                                     kernel_size=downscaling_factor, 
                                     stride=downscaling_factor, padding=0)
        
        """
        Each filter in this layer is a 4x4 filter that extends through the depth of the input volume. 
        Since the input volume has 3 channels, each filter actually has a dimension of 4x4x3 (width x height x depth). 
        There are 96 such filters, each producing one feature map, leading to an output volume with 96 channels.
        """
        

    def forward(self, x):
        b, c, h, w = x.shape # [1, (3, 96, 192, 384), (224, 56, 28, 14), (224, 56, 28, 14)]
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor # 224/4 = 56 # (56,28,17,7)
        print(self.patch_merge(x).shape)
        x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1) # # ( 1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]
        print(type(self).__name__, "size of image after conv" , x.shape)
        return x

# official swin repo replementation   

class PatchEmbed_Official(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
    
class PatchMerging_Official(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)
        print(x.shape)
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C

        ## basically here it 

        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        # conatentate along the long dimension
        print(x.shape)
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        print(x.shape)
        x = self.norm(x)
        x = self.reduction(x)
        print(x.shape)
        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops

In [96]:
image = read_image()

image shape is  torch.Size([1, 3, 224, 224])


In [97]:
patch_merge = PatchMerging(3,96,4)
out = patch_merge(image)

torch.Size([1, 56, 56, 48])


In [98]:
patch_merge = PatchMerging_Conv(3,96,4)
out = patch_merge(image)

torch.Size([1, 96, 56, 56])
PatchMerging_Conv size of image after conv torch.Size([1, 56, 56, 96])


<img src="../images/patch_merging.png" alt="alt text" width="500" align="left"/>

In [84]:
# in the case of the official repo they do a patch embed which converts 1 , 3, 224, 224 into the image size of 1, 56, 56, 96
# which is then passed into a patch merging function unlike the other code  after the first swin block to downsample it to 1, 28, 28, 196
module = PatchEmbed_Official()
out_image = module(image)
print(out_image.shape)
patch_merge = PatchMerging_Official((56,56),96)
out = patch_merge(out_image)
print(out.shape)

torch.Size([1, 3136, 96])
torch.Size([1, 56, 56, 96])
torch.Size([1, 28, 28, 384])
torch.Size([1, 784, 384])
torch.Size([1, 784, 192])
torch.Size([1, 784, 192])


In [85]:
# to understand the stack operation 
test = torch.randn(1,28,28,96)
# basically stacks all of them along to the last dimension 
out_test = torch.cat([test,test,test,test], dim=-1)
out_test.shape

torch.Size([1, 28, 28, 384])

## Swin Transformer Block

we will only be focusing on the first layer

In [17]:
from modules.swin_easy import SwinBlock, PatchMerging
class StageModule(nn.Module):
    def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'

        self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
                                            downscaling_factor=downscaling_factor)

        self.layers = nn.ModuleList([])
        for _ in range(layers // 2):
            self.layers.append(nn.ModuleList([
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
            ]))

    def forward(self, x):
        print(x.shape)
        x = self.patch_partition(x)
        print(x.shape)
        for regular_block, shifted_block in self.layers:
            x = regular_block(x)
            x = shifted_block(x)
        #print(x.shape)
        return x.permute(0, 3, 1, 2)
    

class SwinTransformer(nn.Module):
    def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim, layers=layers[0],
                                  downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim * 2, layers=layers[1],
                                  downscaling_factor=downscaling_factors[1], num_heads=heads[1], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage3 = StageModule(in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4, layers=layers[2],
                                  downscaling_factor=downscaling_factors[2], num_heads=heads[2], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage4 = StageModule(in_channels=hidden_dim * 4, hidden_dimension=hidden_dim * 8, layers=layers[3],
                                  downscaling_factor=downscaling_factors[3], num_heads=heads[3], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim * 8),
            nn.Linear(hidden_dim * 8, num_classes)
        )

    def forward(self, img):
        # img = (1, 3, 224, 224)
        x = self.stage1(img)
        #After permute from stage block x = (1, 96, 56, 56)
        x = self.stage2(x)
        #After permute from stage block x = (1, 192, 28, 28)
        x = self.stage3(x)
        print(x.shape)
        #After permute from stage block x = (1, 384, 14, 14)
        x = self.stage4(x)
        print(x.shape)
        #After permute from stage block x = (1, 768, 7, 7)
        x = x.mean(dim=[2, 3])
        return self.mlp_head(x)


def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)

In [18]:
model = swin_t()
out = model(image)

torch.Size([1, 3, 224, 224])
torch.Size([1, 56, 56, 96])
torch.Size([1, 96, 56, 56])
torch.Size([1, 28, 28, 192])
torch.Size([1, 192, 28, 28])
torch.Size([1, 14, 14, 384])
torch.Size([1, 384, 14, 14])
torch.Size([1, 384, 14, 14])
torch.Size([1, 7, 7, 768])
torch.Size([1, 768, 7, 7])


In [None]:
stage1 = StageModule(3,96,layers=2,downscaling_factor=4,num_heads=3,window_size=7,relative_pos_embedding=True)
stage2 = StageModule(96,96*2,layers=2,downscaling_factor=2,num_heads=6,window_size=7,relative_pos_embedding=True)
stage3 = StageModule(96*2,96*4,layers=6,downscaling_factor=2,num_heads=12,window_size=7,relative_pos_embedding=True)
stage4 = StageModule(96*4,96*8,layers=2,downscaling_factor=2,num_heads=24,window_size=7,relative_pos_embedding=True)

In [10]:
in_channels=3
hidden_dim = 96
layers=(2, 2, 6, 2)
heads=(3, 6, 12, 24)
head_dim=32
window_size=7
downscaling_factors=(4, 2, 2, 2)
relative_pos_embedding=True
stage1 = StageModule(in_channels=in_channels, hidden_dimension=hidden_dim, layers=layers[0],
                                  downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
image = read_image()
out_1 = stage1(image)

image shape is  torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 56, 56, 96])
torch.Size([1, 56, 56, 96])


## Window Attention

In [9]:

from torch import einsum
from einops import rearrange

class CyclicShift(nn.Module):
    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)


def create_mask(window_size, displacement, upper_lower, left_right):
    mask = torch.zeros(window_size ** 2, window_size ** 2) #49,49

    if upper_lower:
        mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
        mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')

    if left_right:
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
        mask[:, -displacement:, :, :-displacement] = float('-inf')
        mask[:, :-displacement, :, -displacement:] = float('-inf')
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')

    return mask


def get_relative_distances(window_size):
    indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
    distances = indices[None, :, :] - indices[:, None, :]
    return distances


class WindowAttention(nn.Module):
    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        """
        dim = hidden_dim = (96,192,384,768)
        heads = num_heads = (3,6,12,24)
        head_dim=32
        
        """
        super().__init__()
        inner_dim = head_dim * heads # (32*3 = 96, 32*6=192, 32*12=384, 32*24=768)

        self.heads = heads
        self.scale = head_dim ** -0.5 # scalling dot product inside the softmax
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        if self.shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True), requires_grad=False)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)  
        # dim = (96,192,384,768)
        # inner_dim = head_dim * 3
        if self.relative_pos_embedding:
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1)) # 13,13
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim)  

    def forward(self, x):
        if self.shifted:
            # x shape is (1, (56,28,14,7), (56,28,14,7), (96,192,384, 768)
            x = self.cyclic_shift(x)
            # x shape is (1, (56,28,14,7), (56,28,14,7), (96,192,384, 768)

        b, n_h, n_w, _, h = *x.shape, self.heads

        qkv = self.to_qkv(x).chunk(3, dim=-1) # basically divided into 3 for q, k , v
        nw_h = n_h // self.window_size  # 56/7 = 8 , (8,4, 2, 1 )
        nw_w = n_w // self.window_size # (8,4, 2, 1 )

        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv)
        # shape of q, k ,v is 
        # (1, h=(3,6,12,24), (nw_h*nw_w)= (64,16,4,1), (w_h*w_w)=49, d = 32)
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale
        # (1, h=(3,6,12,24), (w)= (64,16,4,1), (i,j)=49)

        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        if self.shifted:
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)

        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)
        return out

In the case of swinv1 , layer norm is applied before attention and feedforward

In the case of swinv2 . layer norm is applied after attention and  feedforward

<img src="../images/swin_v1_v2.png" alt="alt text" width="500" align="left"/>

In [10]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
    
class PostNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.norm(self.fn(x, **kwargs))
    
#swin v1
class SwinBlock_v1(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim,
                                                                     heads=heads,
                                                                     head_dim=head_dim,
                                                                     shifted=shifted,
                                                                     window_size=window_size,
                                                                     relative_pos_embedding=relative_pos_embedding)))
        self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))

    def forward(self, x):
        x = self.attention_block(x)
        x = self.mlp_block(x)
        return x
    
#swin v2
class SwinBlock_v2(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        self.attention_block = Residual(PostNorm(dim, WindowAttention(dim=dim,
                                                                     heads=heads,
                                                                     head_dim=head_dim,
                                                                     shifted=shifted,
                                                                     window_size=window_size,
                                                                     relative_pos_embedding=relative_pos_embedding)))
        self.mlp_block = Residual(PostNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))

    def forward(self, x):
        x = self.attention_block(x)
        x = self.mlp_block(x)
        return x

# we include cosine attention in swinv2
import torch.nn.functional as F
class WindowAttention(nn.Module):
    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        """
        dim = hidden_dim = (96,192,384,768)
        heads = num_heads = (3,6,12,24)
        head_dim=32
        
        """
        super().__init__()
        inner_dim = head_dim * heads # (32*3 = 96, 32*6=192, 32*12=384, 32*24=768)

        self.heads = heads
        self.scale = head_dim ** -0.5 # scalling dot product inside the softmax
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        self.tau = nn.Parameter(torch.tensor(0.01))
        # for simplicity but in the orginal paper each head and each layer have a different tau
        if self.shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True), requires_grad=False)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)  
        # dim = (96,192,384,768)
        # inner_dim = head_dim * 3
        if self.relative_pos_embedding:
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1)) # 13,13
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim)  

    def forward(self, x):
        if self.shifted:
            # x shape is (1, (56,28,14,7), (56,28,14,7), (96,192,384, 768)
            x = self.cyclic_shift(x)
            # x shape is (1, (56,28,14,7), (56,28,14,7), (96,192,384, 768)

        b, n_h, n_w, _, h = *x.shape, self.heads

        qkv = self.to_qkv(x).chunk(3, dim=-1) # basically divided into 3 for q, k , v
        nw_h = n_h // self.window_size  # 56/7 = 8 , (8,4, 2, 1 )
        nw_w = n_w // self.window_size # (8,4, 2, 1 )

        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv)
        # shape of q, k ,v is 
        # (1, h=(3,6,12,24), (nw_h*nw_w)= (64,16,4,1), (w_h*w_w)=49, d = 32)
        # First normalizing q and k with respect to each row
        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)

        # Cosine Similarity
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) / self.tau
        # b=batch_size, h=heads (3, 6, 12, 24), w=width (64, 16, 4, 1), i=j=49

        #dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale
        # (1, h=(3,6,12,24), (w)= (64,16,4,1), (i,j)=49)

        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        if self.shifted:
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)

        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)
        return out

In [13]:
# how cosine similariy works for q, k 
import torch
q = torch.randn(1, 64, 49, 5)
print(q[0][0][0])
q = F.normalize(q, p=2, dim=-1) # v/|v|
print(q[0][0][0])
#magnitude should be 1 since its normalise
print(torch.norm(q[0][0][0], p=2))

tensor([ 2.2792, -0.7208, -1.3164,  1.9773,  2.3340])
tensor([ 0.5560, -0.1758, -0.3211,  0.4823,  0.5694])
tensor(1.)


Lets take the first block and see it 
so 
stage1 = StageModule(3,96,layers=2,downscaling_factor=4,num_heads=3,window_size=7,relative_pos_embedding=True)

so there would be 2 layers , one of them the normal MSA and other other one would be SH-MSA

the output would be after the patch merging and then sending it to swin block 

In [14]:
from modules.swin_easy import SwinBlock, PatchMerging
image = read_image()
patch_module =  PatchMerging(3,96,4)
out = patch_module(image)
print(out.shape)

image shape is  torch.Size([1, 3, 224, 224])
torch.Size([1, 56, 56, 96])


In [15]:
hidden_dimension = 96
num_heads = 3
head_dim = 32
window_size = 7
relative_pos_embedding = True

msa_module = SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding)

shmsa_module = SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding)

In [16]:
msa_out = msa_module(out)
print(msa_out.shape)
shmsa_out = shmsa_module(msa_out)
print(shmsa_out.shape)

torch.Size([1, 56, 56, 96])
torch.Size([1, 56, 56, 96])


To understand torch.roll lets consider a (1,81,81) and shift it and see how it works


In [4]:
import torch
# Create a sample tensor
x = torch.linspace(1,81,81).view(9,9)
# Roll the input tensor along the specified dimensions
shifted_x = torch.roll(x, shifts=(-1, -1), dims=(0, 1))

print("Original tensor:")
print(x)
print("\nShifted tensor:")
print(shifted_x)

Original tensor:
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18.],
        [19., 20., 21., 22., 23., 24., 25., 26., 27.],
        [28., 29., 30., 31., 32., 33., 34., 35., 36.],
        [37., 38., 39., 40., 41., 42., 43., 44., 45.],
        [46., 47., 48., 49., 50., 51., 52., 53., 54.],
        [55., 56., 57., 58., 59., 60., 61., 62., 63.],
        [64., 65., 66., 67., 68., 69., 70., 71., 72.],
        [73., 74., 75., 76., 77., 78., 79., 80., 81.]])

Shifted tensor:
tensor([[11., 12., 13., 14., 15., 16., 17., 18., 10.],
        [20., 21., 22., 23., 24., 25., 26., 27., 19.],
        [29., 30., 31., 32., 33., 34., 35., 36., 28.],
        [38., 39., 40., 41., 42., 43., 44., 45., 37.],
        [47., 48., 49., 50., 51., 52., 53., 54., 46.],
        [56., 57., 58., 59., 60., 61., 62., 63., 55.],
        [65., 66., 67., 68., 69., 70., 71., 72., 64.],
        [74., 75., 76., 77., 78., 79., 80., 81., 73.],
        [ 2.,  3.,  4.,  5.,  

Understanding Masking

In [15]:
def create_mask(window_size, displacement, upper_lower, left_right):
    mask = torch.zeros(window_size ** 2, window_size ** 2)

    if upper_lower:
        mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
        mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')

    if left_right:
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
        mask[:, -displacement:, :, :-displacement] = float('-inf')
        mask[:, :-displacement, :, -displacement:] = float('-inf')
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')

    return mask

In [16]:
window_size = 3 
displacement = 1
upper_lower=True
left_right=False
create_mask(window_size, displacement, upper_lower, left_right)

tensor([[0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.]])

In [17]:
window_size = 3 
displacement = 1
upper_lower=False
left_right=True
create_mask(window_size, displacement, upper_lower, left_right)

tensor([[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.]])

In [33]:
upper_lower=True
left_right=True
create_mask(window_size, displacement, upper_lower, left_right)

tensor([[0., 0., -inf, 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, 0., 0., -inf, -inf, -inf, -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, -inf],
        [0., 0., -inf, 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, 0., 0., -inf, -inf, -inf, -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0.]])

## Positional Embeddings

In [60]:
#how the relative positiional encoding works

import torch.nn as nn
import torch


window_size = (2,2)
num_heads = 3
relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
print(relative_coords)
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
print(relative_coords.sum(-1))
"""
This step adjusts the relative height coordinates to ensure they are non-negative. By adding window_size[0] - 1 to the height component (index 0 in the last dimension), 
the minimum possible value is shifted to 0. This adjustment is necessary because the relative positions can be negative (e.g., a token above another token), 
and we want to map these to positive indices
"""
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0

print(relative_coords.sum(-1))
"""
This step adjusts the relative height coordinates to ensure they are non-negative. By adding window_size[0] - 1 to the height component (index 0 in the last dimension), 
the minimum possible value is shifted to 0. This adjustment is necessary because the relative positions can be negative (e.g., a token above another token), 
and we want to map these to positive indices
"""
relative_coords[:, :, 1] += window_size[1] - 1
print(relative_coords.sum(-1))
"""
The adjustment of the height component is scaled by 2 * window_size[1] - 1 to ensure that when we sum the height and width adjustments, each unique (height, width) pair maps to a unique index.
This scaling is crucial because it spreads out the indices for the height differences across a larger range than the width differences, ensuring no two different (height, width) 
pairs end up with the same sum.
"""
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
print(relative_position_index)

tensor([[[ 0,  0, -1, -1],
         [ 0,  0, -1, -1],
         [ 1,  1,  0,  0],
         [ 1,  1,  0,  0]],

        [[ 0, -1,  0, -1],
         [ 1,  0,  1,  0],
         [ 0, -1,  0, -1],
         [ 1,  0,  1,  0]]])
tensor([[ 0, -1, -1, -2],
        [ 1,  0,  0, -1],
        [ 1,  0,  0, -1],
        [ 2,  1,  1,  0]])
tensor([[ 1,  0,  0, -1],
        [ 2,  1,  1,  0],
        [ 2,  1,  1,  0],
        [ 3,  2,  2,  1]])
tensor([[2, 1, 1, 0],
        [3, 2, 2, 1],
        [3, 2, 2, 1],
        [4, 3, 3, 2]])
tensor([[4, 3, 1, 0],
        [5, 4, 2, 1],
        [7, 6, 4, 3],
        [8, 7, 5, 4]])


In [61]:
relative_position_index.view(-1)

tensor([4, 3, 1, 0, 5, 4, 2, 1, 7, 6, 4, 3, 8, 7, 5, 4])

In [62]:
#this would be a trainable parameter and relative_position_index helps us in getting the relevant bias terms for each window
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)].view(4, 4, -1)

In [64]:
relative_position_bias

tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]], grad_fn=<ViewBackward0>)

In [66]:
get_relative_distances(3).shape

torch.Size([9, 9, 2])

In [92]:
#for swinv2
import torch.nn as nn
import torch

class CPB_MLP(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(2, 512, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_heads, bias=False)
        )
        
    def forward(self, x):
        return self.layers(x)

window_size = (2,2)
num_heads = 3
relative_coords_table = torch.stack(torch.meshgrid(
    torch.arange(-(window_size[0]-1), window_size[0]),
    torch.arange(-(window_size[1]-1), window_size[1])
)).permute(1, 2, 0).unsqueeze(0)
print(relative_coords_table)
relative_coords_table[:, :, :, 0] = relative_coords_table[:, :, :, 0]/(window_size[0] - 1)
relative_coords_table[:, :, :, 1] = relative_coords_table[:, :, :, 1]/(window_size[0] - 1)
relative_coords_table *= 8  # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / np.log2(8)
print(relative_coords_table.shape)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
print(relative_coords)
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
print(relative_coords.sum(-1))
"""
This step adjusts the relative height coordinates to ensure they are non-negative. By adding window_size[0] - 1 to the height component (index 0 in the last dimension), 
the minimum possible value is shifted to 0. This adjustment is necessary because the relative positions can be negative (e.g., a token above another token), 
and we want to map these to positive indices
"""
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0

print(relative_coords.sum(-1))
"""
This step adjusts the relative height coordinates to ensure they are non-negative. By adding window_size[0] - 1 to the height component (index 0 in the last dimension), 
the minimum possible value is shifted to 0. This adjustment is necessary because the relative positions can be negative (e.g., a token above another token), 
and we want to map these to positive indices
"""
relative_coords[:, :, 1] += window_size[1] - 1
print(relative_coords.sum(-1))
"""
The adjustment of the height component is scaled by 2 * window_size[1] - 1 to ensure that when we sum the height and width adjustments, each unique (height, width) pair maps to a unique index.
This scaling is crucial because it spreads out the indices for the height differences across a larger range than the width differences, ensuring no two different (height, width) 
pairs end up with the same sum.
"""
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
print(relative_position_index)

tensor([[[[-1, -1],
          [-1,  0],
          [-1,  1]],

         [[ 0, -1],
          [ 0,  0],
          [ 0,  1]],

         [[ 1, -1],
          [ 1,  0],
          [ 1,  1]]]])
torch.Size([1, 3, 3, 2])
tensor([[[ 0,  0, -1, -1],
         [ 0,  0, -1, -1],
         [ 1,  1,  0,  0],
         [ 1,  1,  0,  0]],

        [[ 0, -1,  0, -1],
         [ 1,  0,  1,  0],
         [ 0, -1,  0, -1],
         [ 1,  0,  1,  0]]])
tensor([[ 0, -1, -1, -2],
        [ 1,  0,  0, -1],
        [ 1,  0,  0, -1],
        [ 2,  1,  1,  0]])
tensor([[ 1,  0,  0, -1],
        [ 2,  1,  1,  0],
        [ 2,  1,  1,  0],
        [ 3,  2,  2,  1]])
tensor([[2, 1, 1, 0],
        [3, 2, 2, 1],
        [3, 2, 2, 1],
        [4, 3, 3, 2]])
tensor([[4, 3, 1, 0],
        [5, 4, 2, 1],
        [7, 6, 4, 3],
        [8, 7, 5, 4]])


In [93]:
relative_coords_table , relative_coords_table.shape

(tensor([[[[-1.0566, -1.0566],
           [-1.0566,  0.0000],
           [-1.0566,  1.0566]],
 
          [[ 0.0000, -1.0566],
           [ 0.0000,  0.0000],
           [ 0.0000,  1.0566]],
 
          [[ 1.0566, -1.0566],
           [ 1.0566,  0.0000],
           [ 1.0566,  1.0566]]]]),
 torch.Size([1, 3, 3, 2]))

In [59]:
num_heads = 3
cpb_mlp = CPB_MLP(num_heads)
print(relative_coords_table.shape)
relative_position_bias_table = cpb_mlp(relative_coords_table).view(-1, num_heads) # 9, 3]
print(relative_position_bias_table)
print(relative_position_bias_table.shape)
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)].view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
print(relative_position_bias.shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
print(relative_position_bias.shape)
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)

torch.Size([1, 3, 3, 2])
tensor([[ 0.0844, -0.2220,  0.0212],
        [ 0.1006,  0.0455, -0.0113],
        [ 0.3036,  0.1925, -0.0254],
        [ 0.0752, -0.3241,  0.0239],
        [ 0.0620, -0.1128,  0.0072],
        [ 0.2428,  0.0898, -0.0980],
        [ 0.2094, -0.2668,  0.1333],
        [ 0.0194, -0.2149,  0.0088],
        [ 0.0231, -0.0200, -0.0982]], grad_fn=<ViewBackward0>)
torch.Size([9, 3])
torch.Size([4, 4, 3])
torch.Size([3, 4, 4])


In [39]:
relative_position_index.view(-1)

tensor([4, 3, 1, 0, 5, 4, 2, 1, 7, 6, 4, 3, 8, 7, 5, 4])

Lets understadn this 

<b>Logarithmic Transformation and Its Non-linearity</b>

The logarithmic function is inherently non-linear, characterized by its unique rate of growth: it increases rapidly for small positive values but decelerates as the values become larger. This property of the logarithmic function makes it especially useful for processing relative positions in data, allowing for a more nuanced handling of positional relationships. Let's delve into how this non-linearity manifests and its implications for modeling positional relationships.

<b> Understanding Logarithmic Growth</b>

Logarithmic growth is defined by the relationship between a number $x$ (for $x > 0$) and the power to which a base (commonly 2, e, or 10) must be raised to yield $x$. For example, using base 2, $\log_2(2) = 1$ because $2^1 = 2$, and $\log_2(8) = 3$ because $2^3 = 8$. The logarithm of a number thus reflects its order of magnitude relative to the base.

- **Rapid Growth for Small Values**: The logarithmic function grows more quickly for values of $x$ close to zero. In the context of relative positions, this means that small differences (e.g., between positions 1 and 2) are accentuated after the logarithmic transformation, providing "more granularity" to these distances. This fine distinction allows models to discern closely positioned elements with greater precision.

- **Slower Growth for Larger Values**: As $x$ becomes larger, the rate at which its logarithm increases slows down. This aspect of logarithmic scaling means that larger distances (e.g., between positions 100 and 101) are compressed into almost identical values after transformation. Such "compression" groups distant positions together, allowing the model to treat them as nearly equivalent and to prioritize the analysis of closer positional interactions.

<b>  Implications for Positional Relationships in NLP and CV</b>

In many domains, such as NLP and CV, the interactions between elements that are close to each other often have more significance than interactions between elements that are far apart. For instance, in textual data, the context and meaning of a word are more directly influenced by its immediate neighbors than by words in a distant sentence. Similarly, in images, pixels are more strongly related to adjacent pixels than to those far across the image.

Applying a logarithmic transformation to relative positions allows models to mimic this distribution of significance, emphasizing the importance of nearer interactions over distant ones. By doing so, computational resources can be focused more efficiently on the most meaningful relationships in the data, potentially enhancing the model's ability to learn and make predictions about the structure and content of the input.

In summary, the logarithmic function, through its non-linear scaling, provides a mechanism for models to discriminate finely among close relationships while consolidating distant ones, aligning model processing more closely with the inherent structure of real-world data.


In [50]:
window_size = (2,2)
num_heads = 3
relative_coords_table = torch.stack(torch.meshgrid(
    torch.arange(-(window_size[0]-1), window_size[0]),
    torch.arange(-(window_size[1]-1), window_size[1])
)).permute(1, 2, 0).unsqueeze(0)
print(relative_coords_table)
relative_coords_table[:, :, :, 0] = relative_coords_table[:, :, :, 0]/(window_size[0] - 1)
relative_coords_table[:, :, :, 1] = relative_coords_table[:, :, :, 1]/(window_size[0] - 1)
print(relative_coords_table)
relative_coords_table *= 8  # normalize to -8, 8
print(relative_coords_table)
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / np.log2(8)
print(relative_coords_table)

tensor([[[[-1, -1],
          [-1,  0],
          [-1,  1]],

         [[ 0, -1],
          [ 0,  0],
          [ 0,  1]],

         [[ 1, -1],
          [ 1,  0],
          [ 1,  1]]]])
tensor([[[[-1, -1],
          [-1,  0],
          [-1,  1]],

         [[ 0, -1],
          [ 0,  0],
          [ 0,  1]],

         [[ 1, -1],
          [ 1,  0],
          [ 1,  1]]]])
tensor([[[[-8, -8],
          [-8,  0],
          [-8,  8]],

         [[ 0, -8],
          [ 0,  0],
          [ 0,  8]],

         [[ 8, -8],
          [ 8,  0],
          [ 8,  8]]]])
tensor([[[[-1.0566, -1.0566],
          [-1.0566,  0.0000],
          [-1.0566,  1.0566]],

         [[ 0.0000, -1.0566],
          [ 0.0000,  0.0000],
          [ 0.0000,  1.0566]],

         [[ 1.0566, -1.0566],
          [ 1.0566,  0.0000],
          [ 1.0566,  1.0566]]]])


Another way to get based on swin easy

In [82]:
import numpy as np
def get_relative_distances(window_size):
    indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
    distances = indices[None, :, :] - indices[:, None, :]
    return distances

window_size = 2
tmp = get_relative_distances(2)
print(tmp[:,:,0])
print(tmp[:,:,1])
tmp = get_relative_distances(2) + window_size - 1
print(tmp[:,:,0])
print(tmp[:,:,1])

pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
print(pos_embedding)
pos_embedding[tmp[:,:,0], tmp[:,:,1]]

tensor([[ 0,  0,  1,  1],
        [ 0,  0,  1,  1],
        [-1, -1,  0,  0],
        [-1, -1,  0,  0]])
tensor([[ 0,  1,  0,  1],
        [-1,  0, -1,  0],
        [ 0,  1,  0,  1],
        [-1,  0, -1,  0]])
tensor([[1, 1, 2, 2],
        [1, 1, 2, 2],
        [0, 0, 1, 1],
        [0, 0, 1, 1]])
tensor([[1, 2, 1, 2],
        [0, 1, 0, 1],
        [1, 2, 1, 2],
        [0, 1, 0, 1]])
Parameter containing:
tensor([[ 0.1278,  0.8391,  0.1339],
        [ 0.1860, -0.4751,  1.4897],
        [-0.0032,  1.7040,  0.8556]], requires_grad=True)


tensor([[-0.4751,  1.4897,  1.7040,  0.8556],
        [ 0.1860, -0.4751, -0.0032,  1.7040],
        [ 0.8391,  0.1339, -0.4751,  1.4897],
        [ 0.1278,  0.8391,  0.1860, -0.4751]], grad_fn=<IndexBackward0>)

In [83]:
a = torch.tensor([1,2,3])
print('a: ', a)
print('size of a: ', a.size())
a1 = a[None, :]
print('a1: ', a1)
print('size of a1: ', a1.size())
a2 = a[:, None]
print('a2: ', a2)
print('size of a2: ', a2.size())
d = -a1 - a2
print('d: ', d)
print('size of d: ', d.size())

a:  tensor([1, 2, 3])
size of a:  torch.Size([3])
a1:  tensor([[1, 2, 3]])
size of a1:  torch.Size([1, 3])
a2:  tensor([[1],
        [2],
        [3]])
size of a2:  torch.Size([3, 1])
d:  tensor([[-2, -3, -4],
        [-3, -4, -5],
        [-4, -5, -6]])
size of d:  torch.Size([3, 3])


how to change its window size

In [80]:
from transformers import AutoModel,AutoConfig
config = AutoConfig.from_pretrained(f'microsoft/swinv2-base-patch4-window8-256')

In [86]:
config.window_size = 16

In [88]:
config.window_size

16

In [89]:

model = AutoModel.from_pretrained(f'microsoft/swinv2-base-patch4-window8-256',config=config)

For window size 2x2 u need 9 positonal bias <br>
For window size 3x3 u need 25 positionl bias  <br>
then we need to do bicubic interpolation but in swinv2 it handles it itself since the mlp layer takes care of it

In [105]:
import torch
import torch.nn.functional as F

# Example input
relative_coords_table = torch.randn(3, 3, 2)

# Adjust the data to fit the expected input shape of interpolate(), i.e., (batch_size, channels, height, width)
# Here, we treat the last dimension as channels and add a fake batch dimension
relative_coords_table = relative_coords_table.permute(2, 0, 1).unsqueeze(0)

# Apply bicubic interpolation
# target_size is (5, 5) for the spatial dimensions (height and width)
target_size = (5, 5)
interpolated = F.interpolate(relative_coords_table, size=target_size, mode='bicubic', align_corners=True)

# Remove the fake batch dimension and put the channels back to the last dimension
interpolated = interpolated.squeeze(0).permute(1, 2, 0)

print(interpolated.shape)  # Should be torch.Size([5, 5, 2])


torch.Size([5, 5, 2])


In [106]:
relative_coords_table

tensor([[[[ 0.8989, -0.9325, -0.4685],
          [ 0.9616,  0.9086, -0.9159],
          [ 1.3957,  1.4294,  0.1827]],

         [[ 0.1387, -0.7123, -0.2820],
          [-0.2480,  0.7367,  0.4269],
          [-0.5558, -0.0121,  1.1142]]]])

In [107]:
interpolated.permute(2, 0, 1).unsqueeze(0)

tensor([[[[ 0.8989, -0.0603, -0.9325, -0.8722, -0.4685],
          [ 0.8896,  0.4832, -0.0608, -0.5171, -0.7952],
          [ 0.9616,  1.1061,  0.9086, -0.0086, -0.9159],
          [ 1.1845,  1.4271,  1.3416,  0.4813, -0.4085],
          [ 1.3957,  1.5294,  1.4294,  0.8092,  0.1827]],

         [[ 0.1387, -0.3271, -0.7123, -0.5769, -0.2820],
          [-0.0258,  0.0353,  0.0824,  0.0553,  0.0080],
          [-0.2480,  0.2734,  0.7367,  0.6741,  0.4269],
          [-0.4382, -0.0018,  0.4981,  0.7553,  0.8370],
          [-0.5558, -0.3895, -0.0121,  0.6021,  1.1142]]]])