In [1]:
import math
from functools import partial
from dataclasses import dataclass

import torch
from torch import optim, nn, utils, Tensor
from torchvision.models import swin_v2_t, Swin_V2_T_Weights
from torchvision.ops import StochasticDepth, Permute, MLP
import torch.nn.functional as F

from torchvision.datasets import ImageNet, CocoDetection

from typing import Tuple, Callable, List, Optional
import lightning.pytorch as pl
from torchsummary import summary

In [15]:

def _get_relative_position_bias(
    relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
) -> torch.Tensor:
    N = window_size[0] * window_size[1]
    relative_position_bias = relative_position_bias_table[relative_position_index]  # type: ignore[index]
    relative_position_bias = relative_position_bias.view(N, N, -1)
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
    return relative_position_bias

class SwinAttention(nn.Module):

    def __init__(self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        attention_dropout: float,
        dropout: float,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size

        self.query = self.build_query()
        self.key = self.build_key()
        self.value = self.build_value()
        self.linear = self.build_linear()

        self.attention_dropout = nn.Dropout(attention_dropout)
        self.dropout = nn.Dropout(dropout)

        self.define_relative_position_bias_table()
        self.define_relative_position_index()
        
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        self.cpb_mlp = nn.Sequential(
            nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
        )

    def build_query(self):
        return nn.Linear(self.dim, self.dim)

    def build_key(self):
        return nn.Linear(self.dim, self.dim)
    
    def build_value(self):
        return nn.Linear(self.dim, self.dim)
    
    def build_linear(self):
        return nn.Linear(self.dim, self.dim)

    def _logits(self, q: Tensor, k: Tensor) -> Tensor:
        return F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)

    def _attention(self): ...
    
    def _post_attention(self, attn):
        logit_scale = torch.clamp(self.logit_scale, max=math.log(100.0)).exp()
        attn = attn * logit_scale
        return attn + self.get_relative_position_bias()
    
    # def _linear(self, v: Tensor) -> Tensor:
    #     return self.linear(v)

    def define_relative_position_bias_table(self):
        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
        relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2

        relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
        relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1

        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = (
            torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0
        )
        self.register_buffer("relative_coords_table", relative_coords_table)

    def define_relative_position_index(self):
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1).flatten()  # Wh*Ww*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def get_relative_position_bias(self) -> torch.Tensor:
        relative_position_bias = _get_relative_position_bias(
            self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads),
            self.relative_position_index,  # type: ignore[arg-type]
            self.window_size,
        )
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        return relative_position_bias

    def forward(self, x: Tensor):
        B, H, W, C = x.size()
        print('input: ', x.shape)
        window_size = self.window_size
        pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
        pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        _, pad_H, pad_W, _ = x.shape
        print('pad: ', x.shape)

        shift_size = self.shift_size.copy()
        # If window size is larger than feature size, there is no need to shift window
        if self.window_size[0] >= pad_H:
            shift_size[0] = 0
        if self.window_size[1] >= pad_W:
            shift_size[1] = 0
        # cyclic shift
        # if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
        # print(x.shape)
        # partition windows
        num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
        print('num_windows', num_windows)
        x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
        print(x.shape)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C)
        print(x.shape)
        
        q, k, v = self.query(x), self.key(x), self.value(x)
        q = q.reshape(x.size(0), x.size(1), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        k = k.reshape(x.size(0), x.size(1), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = v.reshape(x.size(0), x.size(1), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = self._logits(q, k) # F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        attn = self._post_attention(attn)
        
        if sum(shift_size) > 0:
            # generate attention mask
            attn_mask = x.new_zeros((pad_H, pad_W))
            h_slices = ((0, -self.window_size[0]), (-self.window_size[0], -shift_size[0]), (-shift_size[0], None))
            w_slices = ((0, -self.window_size[1]), (-self.window_size[1], -shift_size[1]), (-shift_size[1], None))
            count = 0
            for h in h_slices:
                for w in w_slices:
                    attn_mask[h[0] : h[1], w[0] : w[1]] = count
                    count += 1
            attn_mask = attn_mask.view(pad_H // self.window_size[0], self.window_size[0], pad_W // self.window_size[1], self.window_size[1])
            attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, self.window_size[0] * self.window_size[1])
            attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
            print('mask', attn_mask.shape)
            attn = attn.view(x.size(0) // num_windows, num_windows, self.num_heads, x.size(1), x.size(1))
            attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, x.size(1), x.size(1))

        attn = F.softmax(attn, dim=-1)
        attn = self.attention_dropout(attn)

        x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
        x = self.linear(x)
        x = self.dropout(x)

        # reverse windows
        print(x.shape)
        x = x.view(B, pad_H // self.window_size[0], pad_W // self.window_size[1], self.window_size[0], self.window_size[1], C)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)
        print(x.shape)

        # reverse cyclic shift
        if sum(shift_size) > 0:
            x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))

        # unpad features
        x = x[:, :H, :W, :].contiguous()
        return x


In [17]:
b, d = 5, 12
attention = SwinAttention(
    d, [4, 4], [2, 2], 1, 0.1, 0.1
)  # attention layer with cyclic shift
image = torch.rand(b, 14, 14, d)

out = attention(image)

input:  torch.Size([5, 14, 14, 12])
pad:  torch.Size([5, 16, 16, 12])
num_windows 16
torch.Size([5, 4, 4, 4, 4, 12])
torch.Size([80, 16, 12])
mask torch.Size([16, 16, 16])
torch.Size([80, 16, 12])
torch.Size([5, 16, 16, 12])
