In [221]:
import torch 
from torch import nn, einsum
import numpy as np 
from einops import rearrange
import torch.nn.functional as f
import plotly.graph_objs as go


### Patch Merging 

Custom Code for the Patch Merging

In [222]:

# only for STB1, they have used "Patch Merging" and "Linear Embedding" and have not used "Patch Merging"
class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        # patch merging is just used for downscaling the feature set ig 
        
        super().__init__()
        # just changing the feature size without overlaping and kernel size is the window size 
        # kernel should not overlap with each other 
        self.patch_merge = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=downscaling_factor,  # this is fine. Mostly the size of the window
            stride=downscaling_factor, # no window overlappting, so the stride should shift accordingly 
            # so, we would have more control to change the saliency feature of the iamge  
            padding=0
        )
    
    def forward(self, x):
        x = self.patch_merge(x).permute(0, 2, 3, 1) # shift the channel to the last
        return x

In [223]:
torch.manual_seed(0)

B, H, W, C = 1, 2, 2, 3  # 4 tokens across 3 channels 
input = torch.randn(B, H, W, C) * 100 # for better illustration
print("Input: ", input)
layer_norm = nn.LayerNorm(C)  # give no of channels as the dim 
output = layer_norm(input)
print("Output: ", output)

# each token of the 4 tokens normalized with respect to iteself. 

Input:  tensor([[[[ 154.0996,  -29.3429, -217.8789],
          [  56.8431, -108.4522, -139.8595]],

         [[  40.3347,   83.8026,  -71.9258],
          [ -40.3344,  -59.6635,   18.2036]]]])
Output:  tensor([[[[ 1.2191,  0.0112, -1.2303],
          [ 1.3985, -0.5173, -0.8813]],

         [[ 0.3495,  1.0120, -1.3615],
          [-0.3948, -0.9787,  1.3735]]]], grad_fn=<NativeLayerNormBackward0>)


In [224]:
"""
Normalize each token with respect to iteself
"""
class PreNorm(nn.Module):  # Layer Normalization. 
    # normalize the input before sending to the WindowAttention or FeedForward 
    def __init__(self, dim, fn):   # fn - WindowAttention() as the function input or block 
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # number of channels as the input, normalizing across channels
        self.fn = fn
    
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs) # normalize and send to the block 

In [225]:
class PostNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # number of channels as the input, normalizing across channels
        self.fn = fn
    def forward(self, x, **kwargs):
        # apply normalization after the processing 
        return self.norm(self.fn(x, **kwargs)) 

In [226]:
class Residual(nn.Module):
    # fn -- PreNorm(WindowAttention())   this is the function input 
    def __init__(self, fn):  # it is the function or the block 
        super().__init__()
        self.fn = fn
        
    # X - input 
    def forward(self, x, **kwargs):  # **kwargs -- params 
        return self.fn(x, **kwargs) + x   # add the input of the block to the output of the block 

In [227]:
x = torch.linspace(1, 81, 81).view(9, 9)  # 9 * 9 is 81, so, we will get a sqaure tensor 
print(torch.linspace(1, 81, 81))
print(x)

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28.,
        29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42.,
        43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56.,
        57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70.,
        71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 81.])
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18.],
        [19., 20., 21., 22., 23., 24., 25., 26., 27.],
        [28., 29., 30., 31., 32., 33., 34., 35., 36.],
        [37., 38., 39., 40., 41., 42., 43., 44., 45.],
        [46., 47., 48., 49., 50., 51., 52., 53., 54.],
        [55., 56., 57., 58., 59., 60., 61., 62., 63.],
        [64., 65., 66., 67., 68., 69., 70., 71., 72.],
        [73., 74., 75., 76., 77., 78., 79., 80., 81.]])


In [228]:
y = torch.roll(input=x, shifts=(1, 1), dims=(0, 1))
print(y)

tensor([[81., 73., 74., 75., 76., 77., 78., 79., 80.],
        [ 9.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.],
        [18., 10., 11., 12., 13., 14., 15., 16., 17.],
        [27., 19., 20., 21., 22., 23., 24., 25., 26.],
        [36., 28., 29., 30., 31., 32., 33., 34., 35.],
        [45., 37., 38., 39., 40., 41., 42., 43., 44.],
        [54., 46., 47., 48., 49., 50., 51., 52., 53.],
        [63., 55., 56., 57., 58., 59., 60., 61., 62.],
        [72., 64., 65., 66., 67., 68., 69., 70., 71.]])


In [229]:
y = torch.roll(input=x, shifts=(-1, -1), dims=(0, 1))  # dims is h and w 
print(y)

# so, last row and last col window will have unrelated pixels as it came from the other side 
# we don't want this
# so, we apply masking 

tensor([[11., 12., 13., 14., 15., 16., 17., 18., 10.],
        [20., 21., 22., 23., 24., 25., 26., 27., 19.],
        [29., 30., 31., 32., 33., 34., 35., 36., 28.],
        [38., 39., 40., 41., 42., 43., 44., 45., 37.],
        [47., 48., 49., 50., 51., 52., 53., 54., 46.],
        [56., 57., 58., 59., 60., 61., 62., 63., 55.],
        [65., 66., 67., 68., 69., 70., 71., 72., 64.],
        [74., 75., 76., 77., 78., 79., 80., 81., 73.],
        [ 2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,  1.]])


In [230]:
# we need to mask the last row and last column window 
# 8 -> 4 -> 2 -> 1
class CyclicShift(nn.Module):
    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement
    
    # as this is (-displacement), it would be like shift left and up. 
    def forward(self, x):
        return torch.roll(x,shifts=(self.displacement, self.displacement), dims=(1, 2)) # roll to the right and the down. 
    # we have 3 dims, that is h x w x c    # but, this will be changed to c x h x w 
    # so, we are shifting dims in the 1 and 2, understood clearly 

In [231]:
# this is for each of the patches (windows) 
# we are working on pixels 
def create_mask(  # to handle cycleshifted patches 
    window_size, 
    displacement, 
    upper_lower, 
    left_right
):
    # make a matrix mask for the left and right based on the condition
    mask = torch.zeros(window_size ** 2, window_size ** 2)   # (49, 49)   # window size is 7 so, it would be 49 x 49
    # print('Original mask: \n', mask)
    
    if upper_lower:  # displacement = window_size // 2
        # [h, w]
        # from down to up from the lower, 
        mask[-displacement * window_size:, :-displacement * window_size] = float('-inf') # down left section 
        mask[:-displacement * window_size:, -displacement * window_size:] = float('-inf') # up right section
    
    if left_right:
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
        mask[:, -displacement:, :, :-displacement] = float('-inf')
        mask[:, :-displacement, :, -displacement] = float('-inf')
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')
    
    # print("Processed Mask: \n", mask)
    return mask    

Example

In [232]:
window_size = 3
displacement = window_size // 2

mask = torch.zeros(window_size ** 2, window_size ** 2)
print(mask)
print(displacement)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]])
1


In [233]:
upper_lower = True

mask = torch.zeros(window_size ** 2, window_size ** 2)

if upper_lower:  
    # down left section 
    # displacement * window_size
    mask[-displacement * window_size:, :-displacement * window_size] = float(1) 
    mask[:-displacement * window_size:, -displacement * window_size:] = float(1) 

print(mask)


tensor([[0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0.]])


In [234]:
left_right = True

mask = torch.zeros(window_size ** 2, window_size ** 2)

if left_right:
    mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
    mask[:, -displacement:, :, :-displacement] = float(1)
    mask[:, :-displacement, :, -displacement] = float(1)
    mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')

print(mask)

tensor([[0., 0., 1., 0., 0., 1., 0., 0., 1.],
        [0., 0., 1., 0., 0., 1., 0., 0., 1.],
        [1., 1., 0., 1., 1., 0., 1., 1., 0.],
        [0., 0., 1., 0., 0., 1., 0., 0., 1.],
        [0., 0., 1., 0., 0., 1., 0., 0., 1.],
        [1., 1., 0., 1., 1., 0., 1., 1., 0.],
        [0., 0., 1., 0., 0., 1., 0., 0., 1.],
        [0., 0., 1., 0., 0., 1., 0., 0., 1.],
        [1., 1., 0., 1., 1., 0., 1., 1., 0.]])


In [235]:
pattern = mask[:3, :3]
print(pattern)  # this pattern keeps on repeating 

tensor([[0., 0., 1.],
        [0., 0., 1.],
        [1., 1., 0.]])


In [236]:
np.set_printoptions(precision=2, suppress=True)

In [237]:
"""Abs Pos Embedding"""

window_size = 3 # 9 elements
num_of_params = 81  # size of the first stage,  # 9 * 9 = 81 
# relate each param to one other 
pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2), requires_grad=False)

# print(torch.tensor(pos_embedding).apply_(lambda x: float(f"{x:.2f}")))

print(torch.tensor(pos_embedding).clone().detach().numpy())

# print(torch.round(torch.tensor(pos_embedding) * 100) / 100)

[[ 0.47 -0.16  1.44  0.27  0.17  0.87 -0.14 -0.11  0.93]
 [ 1.26  2.    0.05  0.62 -0.41 -0.84 -2.32 -0.22 -0.74]
 [ 0.56  0.26 -0.17 -0.68  0.94  0.49  1.2   0.08 -1.2 ]
 [-0.   -0.52 -0.31 -1.58  1.71  0.21 -0.45 -0.57 -0.56]
 [ 0.59  1.54  0.51 -0.59 -1.33  0.19 -0.07 -0.49 -1.5 ]
 [-0.19  0.45  1.33  1.51  2.08  1.71  2.38 -1.13 -0.32]
 [-1.09 -0.09  0.33 -0.76 -1.6   0.02 -0.75  0.19  0.62]
 [ 0.64 -0.    1.11  0.28  0.43 -0.8  -1.3  -0.75 -1.31]
 [-0.22 -0.33 -0.43  0.23  0.8  -0.18 -0.37 -1.21 -0.62]]



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



#### Relative Positional Embeddings


So, in this case, in total, it would be 25 x 25 param for the RPE. but in abs position, i had 81 x 81 as each element should look to other 8 elements including itself. so, it would be 9. so, for 9 elements, it would be 81 x 81 elements as 9 elements on each row and column 
  

In the SWIN transformer case, 

window_size of 7 would result in parameter size of 13 x 13. 

It means that, each window will look 6 elements in the right and 6 elements in the left and 1 itself. so, 2M-1 where M is the window_size  and same goes for up and down so another 2M-1 

so, the resulting would be (2M-1)x(2M-1) -- this is just the parameter and no the matrix we need ofcourse 



In [238]:
window_size = 3
# 9 x 9 parameters
"""It has 2 channels"""

# (2 * 3 -1) x (2 * 3 - 1) = 5 x 5 = 25 params, but 2 channels 
pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1), requires_grad=False)
print(pos_embedding)

# for a window size = 3, so, the matrix for position embedding would be 9 x 9
# the set of 3 x 3 in 9 x 9 matrix will be repeating itself throughout the matrix 
# we count only the original, so, in each 3 x 3 matrix, we have 5 original elements 
# and across the 9 x 9 matrix, we would have 5 set os 5 elements 
# so 25 parameters will for relative position  

Parameter containing:
tensor([[ 1.0367, -0.6037, -1.2788,  0.0930, -0.6661],
        [ 0.6080, -0.7300, -0.8834,  0.6596,  0.2440],
        [ 1.1646,  0.2886,  0.3866, -0.2011, -0.1179],
        [-0.8294, -1.4073, -1.9003,  0.1307, -0.7043],
        [ 0.3147,  0.1574,  0.3854,  0.5737,  0.9979]])


In [239]:
arr = np.array([1, 2, 3, 4])
print(arr.shape)
arr = arr[:, None]
print(arr.shape)
arr = arr[None, :]
print(arr.shape)

# None adds the new dim at the specified position 

(4,)
(4, 1)
(1, 4, 1)


In [240]:
window_size = 3
indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))  # it will have 3*3 = 9 indices 
print(indices)
print(indices.shape)

distances = indices[None, :, :] - indices[:, None, :]
print(distances.shape)

tensor([[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2],
        [2, 0],
        [2, 1],
        [2, 2]])
torch.Size([9, 2])
torch.Size([9, 9, 2])


In [241]:
x, y, z = np.indices(distances.shape)
scatter = go.Scatter3d(x=x.flatten(), y=y.flatten(), z=z.flatten(),
                       mode='markers',
                       marker=dict(size=5, color=distances.flatten(), colorscale='Viridis'))

# Set up the layout
layout = go.Layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
                   margin=dict(l=0, r=0, b=0, t=0))

fig = go.Figure(data=[scatter], layout=layout)
fig.show()


In [242]:
distances[:, :, 0]  # so, 5 different sub-matrixes you see that you ... 

tensor([[ 0,  0,  0,  1,  1,  1,  2,  2,  2],
        [ 0,  0,  0,  1,  1,  1,  2,  2,  2],
        [ 0,  0,  0,  1,  1,  1,  2,  2,  2],
        [-1, -1, -1,  0,  0,  0,  1,  1,  1],
        [-1, -1, -1,  0,  0,  0,  1,  1,  1],
        [-1, -1, -1,  0,  0,  0,  1,  1,  1],
        [-2, -2, -2, -1, -1, -1,  0,  0,  0],
        [-2, -2, -2, -1, -1, -1,  0,  0,  0],
        [-2, -2, -2, -1, -1, -1,  0,  0,  0]])

In [243]:
distances[:3, :3, 1]  # this keeps on repeating itself for each of the 3 x 3 sub matrix in the whole matrix 

tensor([[ 0,  1,  2],
        [-1,  0,  1],
        [-2, -1,  0]])

In [244]:
distances[:, :, 1]  # each matrix of 3 x 3 is repeating itself. 
# if you notice, that matrix has 5 unique elements
# so, we have 5 x 5 (from the 0th channel and the 1st channel), so, we have 25 params

tensor([[ 0,  1,  2,  0,  1,  2,  0,  1,  2],
        [-1,  0,  1, -1,  0,  1, -1,  0,  1],
        [-2, -1,  0, -2, -1,  0, -2, -1,  0],
        [ 0,  1,  2,  0,  1,  2,  0,  1,  2],
        [-1,  0,  1, -1,  0,  1, -1,  0,  1],
        [-2, -1,  0, -2, -1,  0, -2, -1,  0],
        [ 0,  1,  2,  0,  1,  2,  0,  1,  2],
        [-1,  0,  1, -1,  0,  1, -1,  0,  1],
        [-2, -1,  0, -2, -1,  0, -2, -1,  0]])

In [245]:
indices[None, :, :].shape , indices[:, None, :].shape

# basically, we are making a row of the indices and then the columns of the indices 

(torch.Size([1, 9, 2]), torch.Size([9, 1, 2]))

In [246]:
indices[:, 0]

tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])

In [247]:
indices[:, 1]

tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])

In [248]:
dummy = indices[None, :, :] - indices[:, None, :] 
print(dummy.shape)
print(dummy[:, :, 0], "\n\n", dummy[:, :, 1])

torch.Size([9, 9, 2])
tensor([[ 0,  0,  0,  1,  1,  1,  2,  2,  2],
        [ 0,  0,  0,  1,  1,  1,  2,  2,  2],
        [ 0,  0,  0,  1,  1,  1,  2,  2,  2],
        [-1, -1, -1,  0,  0,  0,  1,  1,  1],
        [-1, -1, -1,  0,  0,  0,  1,  1,  1],
        [-1, -1, -1,  0,  0,  0,  1,  1,  1],
        [-2, -2, -2, -1, -1, -1,  0,  0,  0],
        [-2, -2, -2, -1, -1, -1,  0,  0,  0],
        [-2, -2, -2, -1, -1, -1,  0,  0,  0]]) 

 tensor([[ 0,  1,  2,  0,  1,  2,  0,  1,  2],
        [-1,  0,  1, -1,  0,  1, -1,  0,  1],
        [-2, -1,  0, -2, -1,  0, -2, -1,  0],
        [ 0,  1,  2,  0,  1,  2,  0,  1,  2],
        [-1,  0,  1, -1,  0,  1, -1,  0,  1],
        [-2, -1,  0, -2, -1,  0, -2, -1,  0],
        [ 0,  1,  2,  0,  1,  2,  0,  1,  2],
        [-1,  0,  1, -1,  0,  1, -1,  0,  1],
        [-2, -1,  0, -2, -1,  0, -2, -1,  0]])


In [249]:
a = torch.tensor([1, 2, 3])
print('a : \n', a, a.shape)
a1 = a[None, :]
print('a1: \n', a1, a1.shape)  
a2 = a[:, None]
print('a2: \n', a2, a2.shape)
d = a1 - a2 
print('d : \n', d, d.shape)


a : 
 tensor([1, 2, 3]) torch.Size([3])
a1: 
 tensor([[1, 2, 3]]) torch.Size([1, 3])
a2: 
 tensor([[1],
        [2],
        [3]]) torch.Size([3, 1])
d : 
 tensor([[ 0,  1,  2],
        [-1,  0,  1],
        [-2, -1,  0]]) torch.Size([3, 3])


#### Array boardcasting if operations on arrays with different dimensions

A sensible way of doing element-wise operations on arrays of different (but compatible) shapes 

For the given tensor or array, broadcasting can be done with any number of dimesions with dim=1 it has. Simply, shape of 3 could be 1x3 or 1x1x3 or 1x3x1 or so on

1 can match with any dimensions given

If shorter, prepend 1 with the dimensions. only prepending. 

- when matching the exact dim, just atleast one has to be 1 to make a match, and you can prepend if you want. 
- the results, would be the number other than the dim 1 


In [250]:
def get_relative_distances(window_size):
    # this is creating all the indices i need for the rows and columns. 
    indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)])) # indices of each and every elements of the matrix 
    # print("Indices for the Relative Distance: \n", indices.size())
    
    
    """Array boardcasting"""
    distances = indices[None, :, :] - indices[:, None, :]
    
    return distances

In [251]:

class WindowAttention(nn.Module):
    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        inner_dim = head_dim * heads

        self.heads = heads
        self.scale = head_dim ** -0.5
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        if self.shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True), requires_grad=False)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if self.relative_pos_embedding:
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        if self.shifted:
            x = self.cyclic_shift(x)

        b, n_h, n_w, _, h = *x.shape, self.heads

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        nw_h = n_h // self.window_size
        nw_w = n_w // self.window_size

        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv)

        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale

        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        if self.shifted:
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)

        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)
        return out


In [252]:
class WindowAttention_(nn.Module):
    def __init__(
        self, 
        dim,   # no of channels 
        heads,  # heads = num_heads = (3, 6, 12, 24)
        head_dim, 
        shifted,  # for SW-MSA 
        window_size,   # window size is 7, as the size of the last block is 7 x 7, so, we divide every feature before that in patches of 7 window size 
        relative_pos_embedding,
    ):
        
        """hidden_dim is actually the number of channels technically"""
        # dim = hidden_dim = (96, 192, 384, 768)  # out_channels from each block 
                
        # head_dim = 32 
        
        """
        head * head_dim = channels 
        
        (3 * 32) = 96
        (6 * 32) = 192
        (12 * 32) = 384
        (24 * 32) = 768
        """
        super().__init__()
        
        inner_dim = head_dim * heads   # (channels calculating)
        self.heads = heads 
        
        # for the Attention(Q, K, V) = Softmax(QK^T / sqrt(head_dim) +B)V  # what is B --> constant for regularization
        self.scale = head_dim ** -0.5  # 1 / sqrt(head_dim)
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted
        
        # to make the connections between multiple windows by shifting and padding those 
        if self.shifted:
            # shit all the windows to right and down and pad them by half of the size of the window to make connections between windows 
            # and padd the empty space: 
            """
            2 padding: 
                Naive padding (adding 0)
                Cyclic padding  -- [faster] 
            """
            displacement = window_size // 2  # half of the window size 
            
            # shift them and should be able to shift them back how it ws before 
            """
            With the cyclic shift, the no of padding remains the same, and we don't need any extra computations to process the padded 0 as in the naive solution
            """
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            
            # problem of last row and last column would be in problem with this Cyclic Shift 
            # when we shift to right and down. 
            """Windows at last row and last column would be related to unrelated patches from the other side. we don't want this"""
            # create a matrix to mask. 0-no mask 1-mask (-infinity). make a matrix of n**2 x n**2 and then make the masking. 
            
            # masks are NOT learnable params, requires_grad=False
            # Last row matrix 
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            
            # last column matrix (COMPLEX.. couldn't understand)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,                                                             
                                                            upper_lower=False,  left_right=True), requires_grad=False)
            
        
        # query, key, value. 
        # increase the output by 3
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)  # inner_dim is the channels of the current layer 
        
        # positional embedding 
        """Absolute positional embedding"""
        self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2)) # (49 x 49)
        
        """Relative Positional embedding"""
        if self.relative_pos_embedding:
            self.relative_indices = get_relative_distances(window_size) + window_size - 1  # this will create a big matrix where we will put matrix inside that      
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))  # 13 x 13 parameters --- 2 * 7 - 1   
        else:
            # normal positional embedding
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))         
        """ViT paper:
        we divide image into pathces and each patch is itself a token. 
        # where the patch belongs - pos embedding 
        """
        
        self.tau = nn.Parameter(torch.tensor(0.01), requires_grad=True) # learnable parameter. Initialized to 0.01 initialy according to the paper 
        
        # first normalizing q and k with respect to each row 
        
        self.to_out = nn.Linear(inner_dim, dim)  # inner_dim = channels 
        # dim = hidden_dim = channels 
    
    def forward(self, x):
        if self.shifted:
            x = self.cyclic_shift(x)  # cyclic shift won't change the dimensions 
            # we need to use masking technique if we are using cyclic shift.
            # print(x.size())
        b, n_h, n_w, _, h = *x.shape, self.heads  # (1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768))
        
        # print(self.to_qkv(x).size()) # (1, (56, 28, 14, 7), (56, 28, 14, 7), (288, 576, 1152, 2304))
        # channels are just changed, h w don't change
        qkv = self.to_qkv(x).chunk(3, dim=-1) # across channels 
        # q k v will have the same size 
        
        # rows and cols of windows 
        # no of windows i have in each stage  
        nw_h = n_h // self.window_size   # 8, 4, 2, 1  for all 4 stages 
        nw_w = n_w // self.window_size
        
        # using rearrage from the einops 
        # head in first dim = 3
        # head_dim = 32 
        q, k, v = map(
            lambda t: rearrange(t,
                                # b, h=#heads, (nw_h*nw_w) = (64, 16, 4, 1), (w_h, w_w)=(7 * 7 = 49), d=32  
                                'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',  # if in brackets, it will multiply 
                                h=h, w_h=self.window_size, w_w=self.window_size),  # decomposition. adding new dim, so epcify
            qkv
        )
        
        # print(q.size()) # k and v would have the same shape 
        # (b=1, h=(3, 6, 12, 24), (nw_h * nw_w)=(74, 16, 4, 1), (w_h * w_w) = 49, d=32)
        
        # QK_t  """ i j -- window size (49 pixels, (7 x 7))"
        # h = #heads (3, 6, 12, 24) 
        # w = (64, 16, 4, 1) # how the pixels are grouped into window_size and grouped 
        # DOT PRODUCT SIMILARITY 
        # dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale  # attention formula 
        # h - no of heads  
        # w - no of windows 
        # i, j - 49 (7 * 7)  #  
        # d = 32 # head_dim 
        
        """Cosine similiarity"""
        # sim(q_i, k_j) = cos(q_i, k_j) / t_ + B_ij      # t_ is the tou which is the learnable paramter 
        # B_ij is the relative positional bias between pixel i and j. 
        
        # eulidean norm 
        # to normalize a vector, we divide each element of the vector by the L2 norm  
        q = f.normalize(q, p=2, dim=-1) # L2norm as p=2 ... ensuring that each vector  has unit length
        k = f.normalize(k, p=2, dim=-1)  # d = 32 # normalize along the d=32 
        
        # cosine similarity 
        # if you multiply v by the reciprocal of its length ||V||_2 you are essentially tranforming the vector to have a magnitude of 1. 
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) / self.tau   
        
        
        if self.relative_pos_embedding:
            tmp1 = self.relative_indices[:, :, 0] # the 1st channel is taken 
            # tmp2 = self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]] # (49, 49, 2)
            
            
            # indexing with and adding the elements to the "dots", acessing from the  
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            """Softmax(QK_t / sqrt(d) + B)*V """ # B - pos embedding 
            # pos_embedding - (49, 49)
            # dots - b h w i j
            dots += self.pos_embedding # added to every window (w) of size (49, 49) that is (i, j)

        """We need to add masking to last row and last column as we used cyclic_shift"""
        
        if self.shifted:
            # tmp1 = self.upper_lower_mask
            # tmp2 = self.left_right_mask
            """we are applying to all the windows [3] dim """ 
            # dots - (b, h, w, i, j) # to all the heads and to only the last row and last column windows   
            # no of windows        # only to windows in the last row and last column 
            dots[:, :, -nw_w] += self.upper_lower_mask  # just slice the last row by -nw_w  # nw_w - no of windows ( n_w x n_w )
            # mask only the last column windows 
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask # add to the last column
            # for the first stage, nw_w = 8, so, nw_w - 1 = 7 
            """[start:stop:step] in slicing""" # so, 7::8 # means do, till the last
            # mapping only the last column.  
        
        attn = dots.softmax(dim=-1) # across all the windows # 
        # (b, h=(3, 6, 12, 24), (nw_h * nw_w)=(64, 16, 4, 1), (w_h*w_w)=49, d=32) where d=head_dim, h=#heads 
        
        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        # rearrage the output as to the original input to the "WindowAttention" as the widnow attention wouldn't change the dim or the feature size, only the PatchMerging would do that
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        # (1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768))
        
        out = self.to_out(out)  # (1, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768))
        # if we are in shifted window, we need to cycle back to get the original feature set 
        
        if self.shifted:
            out = self.cyclic_back_shift(out)
        
        return out        
        

##### In the SWIN Transformer v2, they have used Cosine Similarity instead of the Dot product similarity 



In [253]:
# data matrix to index 
p = torch.tensor([[1, 2], 
                  [3, 4]])

print(p.size())

# indices 
r = torch.tensor([[[0, 0], [0,0], [0,0]],
                 [[1,1], [1,1], [1,1]],
                 [[0,1], [0,1], [0,1]]])  # 3, 3, 2
print(r.size())

# indexing the tensor with another tensor 
print(p[r[:, :, 0], r[:, :, 1]])  # great work 

torch.Size([2, 2])
torch.Size([3, 3, 2])
tensor([[1, 1, 1],
        [4, 4, 4],
        [2, 2, 2]])


In [254]:
# for dots and pos_embed adding 
do = torch.randn(1, 1, 2, 5, 5) # 2 windows  
print(do)

po = torch.ones(5, 5) # pos_embedding
print(po)

su = do + po
print(su)

# 1 is added to each of the element and added to all the windows 
# so, pos_embedding is added to every windows ( 7 x 7) and it is done to all the windows in the feature set as the feature set is divided into multiple windows 

#----

tensor([[[[[ 0.3016, -0.1073,  0.9985, -0.4987,  0.7611],
           [ 0.6183, -0.2994, -0.1878, -0.1201,  0.3605],
           [-0.3140, -1.0787,  0.2408, -1.3962,  0.1136],
           [ 1.1047, -1.5616, -0.3546,  1.0811,  0.1315],
           [ 1.5735,  0.7814,  0.9874, -1.4878,  1.4708]],

          [[ 0.2756,  0.6668, -0.9944, -1.1894, -1.1959],
           [ 1.3119, -0.2098,  0.4069,  0.3946, -1.2537],
           [ 0.9868, -0.4947, -1.2830,  0.4386, -0.0107],
           [ 1.3384, -0.2794,  0.2877, -0.0334, -1.0619],
           [-0.1144,  0.1954, -0.7371,  1.7001,  0.3462]]]]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
tensor([[[[[ 1.3016,  0.8927,  1.9985,  0.5013,  1.7611],
           [ 1.6183,  0.7006,  0.8122,  0.8799,  1.3605],
           [ 0.6860, -0.0787,  1.2408, -0.3962,  1.1136],
           [ 2.1047, -0.5616,  0.6454,  2.0811,  1.1315],
           [ 2.5735,  1.7814,  

In [255]:
arr = np.arange(64).reshape(1, 8, 8)  # batch, 8 x8 for stage 1
print(arr)

result = arr[:, :, 7::8]  # all the rows, and last column 
print(result)

[[[ 0  1  2  3  4  5  6  7]
  [ 8  9 10 11 12 13 14 15]
  [16 17 18 19 20 21 22 23]
  [24 25 26 27 28 29 30 31]
  [32 33 34 35 36 37 38 39]
  [40 41 42 43 44 45 46 47]
  [48 49 50 51 52 53 54 55]
  [56 57 58 59 60 61 62 63]]]
[[[ 7]
  [15]
  [23]
  [31]
  [39]
  [47]
  [55]
  [63]]]


In [256]:
do = torch.randn(1, 1, 2, 5, 5)  # 2 windows of size 5 x 5 
print(do)

mo = torch.ones(5, 5) # to add with windows 
print(mo)

do[:, :, 1] += mo  # applies to only the last window 
print(do) # pos embedding is added only to the 2nd column 

tensor([[[[[-0.1448,  0.6376, -0.2813, -1.3299, -0.6538],
           [ 1.7198, -0.9610, -0.6375, -0.8870,  0.8388],
           [ 1.1529, -1.7611, -1.1070, -1.7174,  1.5346],
           [-0.0032,  1.4403, -0.1106,  0.5769, -0.1692],
           [ 1.1887, -0.1575, -0.0455,  0.6485, -0.8707]],

          [[ 0.1447,  1.9029,  0.3904,  0.0331, -1.0234],
           [ 0.7335,  1.1177,  0.5851, -1.1560, -0.5354],
           [-0.8637, -0.9069, -0.5918,  0.1508, -1.0411],
           [-0.7205, -2.2148,  0.9403, -1.1470,  0.7928],
           [ 0.0832,  0.4228, -1.8687, -1.1057,  0.1437]]]]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
tensor([[[[[-0.1448,  0.6376, -0.2813, -1.3299, -0.6538],
           [ 1.7198, -0.9610, -0.6375, -0.8870,  0.8388],
           [ 1.1529, -1.7611, -1.1070, -1.7174,  1.5346],
           [-0.0032,  1.4403, -0.1106,  0.5769, -0.1692],
           [ 1.1887, -0.1575, -

In [257]:
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        
        # hidden_dim --- 4 times the size of the input (channels)
        self.net = nn.Sequential(
            # linear layers 
            nn.Linear(dim, hidden_dim),  
            # activation function  
            nn.GELU(),  # Gaussian Error Linear Units
            nn.Linear(hidden_dim, dim),  # process with hidden_dim and bring back to the same dim after processing 
        )
    
    def forward(self, x):
        return self.net(x)  # FFC 2 layer network

##### Cosine similarity 

Doing the normalization stuff

In [258]:
a = torch.tensor([[[1., 2], [2, 2], [1, 1]],
                  [[1, 4], [3, 6], [5, 6]],
                  [[9, 1], [2, 8], [9, 7]]])
print(a) 
print(a.size())
print(a[:, :, 0], a[:, :, 1])
# p = exponent value of the norm function 
a = f.normalize(a, p=2, dim=-1)
print(a)
print(a.size())


tensor([[[1., 2.],
         [2., 2.],
         [1., 1.]],

        [[1., 4.],
         [3., 6.],
         [5., 6.]],

        [[9., 1.],
         [2., 8.],
         [9., 7.]]])
torch.Size([3, 3, 2])
tensor([[1., 2., 1.],
        [1., 3., 5.],
        [9., 2., 9.]]) tensor([[2., 2., 1.],
        [4., 6., 6.],
        [1., 8., 7.]])
tensor([[[0.4472, 0.8944],
         [0.7071, 0.7071],
         [0.7071, 0.7071]],

        [[0.2425, 0.9701],
         [0.4472, 0.8944],
         [0.6402, 0.7682]],

        [[0.9939, 0.1104],
         [0.2425, 0.9701],
         [0.7894, 0.6139]]])
torch.Size([3, 3, 2])


### Swin Transformer Block

In [259]:
# output and input of the SWIN transformer block is the same 
# after patch merging, the feature set is sent to the STB, and then the input and the output of the STB would be the same, and then, 
# patch merging will again reduce the size of th feature

# so, PM is the one down-scaling the feature 
class SwinBlock(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, 
                 window_size, relative_pos_embedding):
        
        super().__init__()
        
        # Residual connection after the layer normalization and the Attention
        self.attention_block = Residual( # Residual on the whole block... 
            PostNorm(  # Layer normalization, we send the dim size and after computoing the attention
                dim,   # which dim of the input to the layer norm 
                
                WindowAttention(  # HOW IT WORKS
                    dim=dim, 
                    heads=heads, 
                    head_dim=head_dim, 
                    window_size=window_size, 
                    shifted=shifted,
                    relative_pos_embedding=relative_pos_embedding
                )
            )
        ) # MLP is the FeedForward   # we do residual connection and before that in the block, we have LN (Layer Normalization)
        """
        W-MSA: 
            [LN, W-MSA] -> [LN, MLP]
        SW-MSA:  
            [LN, SW-MSA] -> [LN, MLP]

        """
        # FeedForward after the first block according to the block in the paper 
        self.mlp_block = Residual(PostNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))
        # residual -> PreNorm -> FeedForward
    def forward(self, x):
        x = self.attention_block(x)  
        x = self.mlp_block(x)
        return x

In [260]:
class StageModule(nn.Module):
    def __init__(self, 
                 in_channels, 
                 hidden_dimension, 
                 layers, 
                 downscaling_factor, 
                 num_heads, 
                 head_dim, 
                 window_size, 
                 relative_pos_embedding
                 ):
        super().__init__()
        
        # for SwinTrans-T: {2, 2, 6, 2} for 4 stage blocks 
        # so, we atleast need 2 sub-blocks of a stage block, so, it should be in even 
        assert layers % 2 == 0 # stage layers need to be divisible by 2 for regular and shifted block 
        
        # Patch Partition  --  done before the 4 stages 
        # Dones at every stage 
        self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension, 
                                            downscaling_factor=downscaling_factor)
        
        # Stage 1: LE + STB   # only first stage has Linear Embedding 
        # Stage 2: PM + STB
        # Stage 3: PM + STB
        # Stage 4: PM + STB
        self.layers = nn.ModuleList([])
        
        for _ in range(layers // 2):  # we will go only one iteration. we will use the same n times, so the paramters don't get added up 
            self.layers.append(
                nn.ModuleList(
                    [
                        # W-MSA
                        SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                              shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding),  # stage = False
                        # SW-MSA
                        SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                              shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding)                        
                    ]
                )
            )
    def forward(self, x):
        # path merging 
        print('Before path merging:', x.size())   # (1 batch_size, (3, 96, 192, 384) channels, (3, 224, 56, 28, 14) height, (224, 56, 28, 14) width)
        x = self.patch_partition(x)  # it will down scale the feature image 
        print('After patch merging:', x.size())
        # input and output of the SWIN transformer block is the same size 
        # patch_merging is the one which reduces the size of the feature image 
        
        """
        For the 1st stage, 2 blocks -> Patch Partition and Linear Embedding 
        (just resizes the feature)
        # In paper code, they used only one block for the stage 1 instead of 2 
        """
        # just pass through all the blocks 
        for regular_block, shifted_block in self.layers:
            
            # these two blocks, W-MSA and SW-MSA won't change the size of th feature    
            x = regular_block(x)
            x = shifted_block(x)
        
        return x.permute(0, 3, 1, 2) # batch, channels, h, w 

In [261]:
class SwinTransformer(nn.Module):
    def __init__(
        self,
        *, 
        hidden_dim, 
        layers, 
        heads, 
        channels=3, 
        num_classes=1000,   # for Imagenet as used in the paper 
        head_dim=32, 
        window_size=7, 
        downscaling_factors=(4, 2, 2, 2), 
        relative_pos_embedding=True
    ):
        super().__init__()
        
        """
        As we go forward with 4 stages, the size of the image is shrinking and the channels goes on increasing
        """
        # for processing the information in the image
        self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim, layers=layers[0],
                                  downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        
        # downscaluing_factor for the patch_merging 
        self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim * 2, layers=layers[1],
                                  downscaling_factor=downscaling_factors[1], num_heads=heads[1], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        
        self.stage3 = StageModule(in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4, layers=layers[2],
                                  downscaling_factor=downscaling_factors[2], num_heads=heads[2], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        
        self.stage4 = StageModule(in_channels=hidden_dim * 4, hidden_dimension=hidden_dim * 8, layers=layers[3],
                                  downscaling_factor=downscaling_factors[3], num_heads=heads[3], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        
        # classification or decision head 
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim * 8),
            nn.Linear(hidden_dim * 8, num_classes)
        )
    
    def forward(self, img):
        
        x = self.stage1(img)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)      # (batch, 768, 7, 7) --- 4 dim 
        
        # ?????
        x = x.mean(dim=[2, 3])     # (batch, 1, 768)    --- 3 dim --> 2 dim without batch 
        return self.mlp_head(x)  # give the linear form   # categrotical output 
    
    """this mlp_head could be changed to generate the text or LLM output and we can also embed the XAI into this.. 
    Possible"""
    

SWIN- 

T: C=96, layers = {2, 2, 6, 2}

S: C=96, layers = {2, 2, 18, 2}

B: C=128, layers = {2, 2, 18, 2}

L: C=192, layers = {2, 2, 18, 2}

In [262]:
def swin_t(
    hidden_dim=96,  # no of channels at begning and then 2C, 4C, and 8C 
    layers=(2, 2, 6, 2),  # in the paper 
    heads=(3, 6, 12, 24),
    **kwargs
):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)

def swin_s(
    hidden_dim=96, 
    layers=(2, 2, 18, 2),   
    heads=(3, 6, 12, 24),   
    **kwargs
):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)

def swin_b(
    hidden_dim=128, 
    layers=(2, 2, 18, 2), 
    heads=(4, 8, 16, 32),
    **kwargs
):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)

def swin_l(
    hidden_dim=192,  
    layers=(2, 2, 18, 2),  
    heads=(6, 12, 24, 48),
    **kwargs
):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)



### Inference

In [266]:
net = swin_t(
    hidden_dim=96, # start channel 
    layers=(2, 2, 6, 2), # 4 stages
    heads=(3, 6, 12, 24),
    channels=3, 
    num_classes=3,
    head_dim=32,
    window_size=7,
    downscaling_factors=(4, 2, 2, 2),
    relative_pos_embedding=True
)

print(net)

SwinTransformer(
  (stage1): StageModule(
    (patch_partition): PatchMerging(
      (patch_merge): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    )
    (layers): ModuleList(
      (0): ModuleList(
        (0): SwinBlock(
          (attention_block): Residual(
            (fn): PostNorm(
              (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (fn): WindowAttention(
                (to_qkv): Linear(in_features=96, out_features=288, bias=False)
                (to_out): Linear(in_features=96, out_features=96, bias=True)
              )
            )
          )
          (mlp_block): Residual(
            (fn): PostNorm(
              (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (fn): FeedForward(
                (net): Sequential(
                  (0): Linear(in_features=96, out_features=384, bias=True)
                  (1): GELU(approximate='none')
                  (2): Linear(in_features=384, out_features=96, bias

In [275]:
dummy_x = torch.randn(1, 3, 224, 224)  # dummy input image tensor 
logits = net(dummy_x) # (1, 3) # as it would return the softmax 
print(logits)

Before path merging: torch.Size([1, 3, 224, 224])
After patch merging: torch.Size([1, 56, 56, 96])
Before path merging: torch.Size([1, 96, 56, 56])
After patch merging: torch.Size([1, 28, 28, 192])
Before path merging: torch.Size([1, 192, 28, 28])
After patch merging: torch.Size([1, 14, 14, 384])
Before path merging: torch.Size([1, 384, 14, 14])
After patch merging: torch.Size([1, 7, 7, 768])
tensor([[ 0.3548, -0.2833, -0.5480]], grad_fn=<AddmmBackward0>)
