In [1]:
import torch
import torch.nn as nn
from timm.models.layers import DropPath
import torch.nn.functional as F
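# Compared to the vanilla attention in ./models/token_transformer.py, the attention class below adds:
# num_tokens: number of tokens (expected to equal token_map_height * token_map_width)
# head_separate: If True, we define separate weight matrices for different heads.
#                Else, we define one single weight matrix for different heads.
#                (Currently only stored on the module as `head_sep`; the code below does not read it.)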

class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer; a falsy value (None or 0)
            falls back to ``in_features``.
        out_features: Size of the last dimension of the output; a falsy value
            falls back to ``in_features``.
        act_layer: Zero-argument callable that builds the activation module.
        drop: Dropout probability; one ``nn.Dropout`` instance is applied after
            both the activation and the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        # Submodules are created in the order fc1, act, fc2, drop so parameter
        # initialisation consumes the RNG in a fixed order.
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the block to ``x`` along its last dimension and return the result."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

def calculate_local_3x3_index_list(token_map_height, token_map_width):
    """List the flat indices of every token's 3x3 neighbourhood.

    Tokens are numbered row-major on a ``token_map_height x token_map_width``
    grid. For each token the nine neighbours are visited with the row offset
    (-1, 0, 1) as the outer loop and the column offset (-1, 0, 1) as the inner
    loop. A neighbour that falls outside the grid is encoded as
    ``token_map_height * token_map_width``, i.e. one past the last valid token
    index.

    Returns:
        A list with one entry per token, each a list of exactly nine ints.
    """
    out_of_grid_index = token_map_height * token_map_width
    offsets = (-1, 0, 1)

    def flat_index(r, c):
        # Positions off the grid share the single sentinel index.
        if 0 <= r < token_map_height and 0 <= c < token_map_width:
            return r * token_map_width + c
        return out_of_grid_index

    return [
        [flat_index(row + d_row, col + d_col) for d_row in offsets for d_col in offsets]
        for row in range(token_map_height)
        for col in range(token_map_width)
    ]


class Kernel_3x3_Convolutional_Attention_t2t(nn.Module):
    """Attention layer that sums a global branch and a 3x3-neighbourhood branch.

    Tokens are assumed to lie row-major on a ``token_map_height x
    token_map_width`` grid. The global branch is ordinary softmax attention.
    The local branch keeps, for every query token, only the attention logits
    of the tokens in its 3x3 neighbourhood, multiplies them by
    ``local_3x3_kernel_weight`` and applies a softmax over the full row. Both
    branches share the same values ``v``; their sum is projected and added to
    ``v`` as a skip connection.

    Args:
        dim: Size of the last dimension of the input.
        num_tokens: Number of tokens; must equal
            ``token_map_height * token_map_width``.
        token_map_height: Height of the token grid.
        token_map_width: Width of the token grid.
        head_separate: Stored as ``self.head_sep``; not read by this class.
        abs_kernel_size: Number of columns of ``local_3x3_kernel_weight``.
        num_heads: Number of attention heads; ``in_dim`` must be divisible by it.
        in_dim: Output width of the layer (q, k and v each have this total
            width). Defaults to ``dim`` when ``None``.
        qkv_bias: Whether the qkv projection has a bias.
        qk_scale: Logit scale; a falsy value selects ``(dim // num_heads) ** -0.5``.
        attn_drop: Dropout probability applied to both attention maps.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_tokens != token_map_height * token_map_width``.
    """

    def __init__(self, dim, num_tokens, token_map_height, token_map_width, head_separate=False, abs_kernel_size=9, num_heads=8, in_dim=None, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        # The neighbourhood indices (and their out-of-grid sentinel) are derived
        # from the grid size, so a mismatching token count can never run.
        if num_tokens != token_map_height * token_map_width:
            raise ValueError(
                "num_tokens ({}) must equal token_map_height * token_map_width ({} * {})".format(
                    num_tokens, token_map_height, token_map_width))
        if in_dim is None:
            in_dim = dim
        self.num_heads = num_heads
        self.num_tokens = num_tokens
        self.token_map_height = token_map_height
        self.token_map_width = token_map_width
        self.in_dim = in_dim
        self.head_sep = head_separate
        # The scale is derived from `dim`, not `in_dim`; kept as-is so existing
        # models produce the same numbers.
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # NOTE(review): a 3x3 neighbourhood always has 9 entries, so this is
        # fixed at 9 regardless of the `abs_kernel_size` argument, which only
        # sizes the weight below. A value other than 9 (or 1, which broadcasts)
        # cannot be multiplied with the gathered logits in forward().
        self.abs_kernel_size = 9
        self.qkv = nn.Linear(dim, in_dim * 3, bias=qkv_bias)
        self.local_attn_drop = nn.Dropout(attn_drop)
        self.global_attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(in_dim, in_dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # we can initialize the mask here by replacing torch.ones
        self.local_3x3_kernel_weight = nn.Parameter(torch.ones(num_tokens, abs_kernel_size), requires_grad=False)
        self.calculate_local_3x3_index_tensor()

    def forward(self, x):
        """Compute the attention output for ``x`` of shape (B, num_tokens, dim).

        Returns:
            Tensor of shape (B, num_tokens, in_dim).
        """
        B, N, C = x.shape

        # Split in_dim across heads; with num_heads == 1 this is the full in_dim.
        head_dim = self.in_dim // self.num_heads
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # shape of Q, K and V: (B, num_heads, N, head_dim)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        # Append one zero column: out-of-grid neighbours point at index N and
        # therefore gather a logit of 0.
        padded_attn = F.pad(input=attn, pad=[0, 1], mode='constant', value=0)

        filterd_attention_map = self.batch_attention_filter(padded_attn)

        weighted_filterd_attention_map = torch.mul(filterd_attention_map, self.local_3x3_kernel_weight)

        # .to(int64) is a no-op for the registered buffer; it only matters if
        # `local_index` was replaced by a tensor of another dtype.
        index_expand = self.local_index.to(torch.int64).unsqueeze(0).unsqueeze(0).expand(
            B, self.num_heads, self.num_tokens, self.abs_kernel_size)

        # Allocate the scatter target from the source so device and dtype
        # always match (scatter_ requires both).
        local_attn = weighted_filterd_attention_map.new_zeros(
            (B, self.num_heads, self.num_tokens, self.num_tokens + 1)
        ).scatter_(-1, index_expand, weighted_filterd_attention_map)

        # Drop the padding column that collected the out-of-grid entries.
        local_attn = local_attn[:, :, :, :-1]

        # NOTE(review): non-neighbour positions hold a logit of 0 (not -inf),
        # so they still receive softmax weight. Left unchanged because altering
        # it would change the output of existing models; confirm the intent.
        local_attn = local_attn.softmax(dim=-1)
        global_attn = attn.softmax(dim=-1)
        # shape of attn : B, num_heads, num_tokens, num_tokens
        local_attn = self.local_attn_drop(local_attn)
        global_attn = self.global_attn_drop(global_attn)

        x_local = (local_attn @ v).transpose(1, 2).reshape(B, N, self.in_dim)
        x_global = (global_attn @ v).transpose(1, 2).reshape(B, N, self.in_dim)
        x = x_local + x_global

        x = self.proj(x)
        x = self.proj_drop(x)

        # skip connection: the input has width `dim` while the output has width
        # `in_dim`, so v (heads merged back together) is used instead of the input.
        x = v.transpose(1, 2).reshape(B, N, self.in_dim) + x

        return x

    def batch_attention_filter(self, attention_map):
        """Gather each query's 3x3-neighbourhood entries from ``attention_map``.

        Args:
            attention_map: Tensor of shape (B, num_heads, N1, N2) whose last
                dimension is indexed by the values stored in ``local_index``.

        Returns:
            Tensor of shape (B, num_heads, num_tokens, 9).
        """
        num_batch, num_heads, N1, N2 = attention_map.shape
        index_expand = self.local_index.to(torch.int64).expand(num_batch, num_heads, -1, -1)
        return torch.gather(attention_map, -1, index_expand)

    def calculate_local_3x3_index_tensor(self):
        """Build ``self.local_index``, an int64 tensor of shape (num_tokens, 9).

        Row ``t`` holds the flat indices of the 3x3 neighbourhood of token
        ``t`` (row offset outer, column offset inner). Neighbours outside the
        grid are encoded as ``token_map_height * token_map_width``. The tensor
        is registered as a non-persistent buffer so that it follows the module
        across devices without adding a key to the state dict.
        """
        height, width = self.token_map_height, self.token_map_width
        offsets = torch.tensor([-1, 0, 1])

        # Broadcast to (height, width, 3, 3): grid row, grid col, row offset, col offset.
        neighbour_rows = torch.arange(height).view(height, 1, 1, 1) + offsets.view(1, 1, 3, 1)
        neighbour_cols = torch.arange(width).view(1, width, 1, 1) + offsets.view(1, 1, 1, 3)

        inside = (
            (neighbour_rows >= 0) & (neighbour_rows < height)
            & (neighbour_cols >= 0) & (neighbour_cols < width)
        )
        flat_index = neighbour_rows * width + neighbour_cols
        flat_index = flat_index.masked_fill(~inside, height * width)

        local_index_tensor = flat_index.reshape(height * width, 9).to(torch.int64)
        self.register_buffer("local_index", local_index_tensor, persistent=False)

    def calculate_local_3x3_index_list(self):
        """Return, per token, the flat indices of its in-grid 3x3 neighbours.

        Unlike ``local_index``, out-of-grid neighbours are omitted rather than
        replaced by a sentinel, so border tokens get fewer than nine entries.
        This method is not called anywhere in this class.
        """
        height, width = self.token_map_height, self.token_map_width
        offsets = (-1, 0, 1)
        all_local_index_list = []

        for row in range(height):
            for col in range(width):
                all_local_index_list.append([
                    (row + i) * width + (col + j)
                    for i in offsets
                    for j in offsets
                    if 0 <= row + i < height and 0 <= col + j < width
                ])

        return all_local_index_list

In [2]:
# Width of each token after the attention layer.
token_dim = 64

# Input tokens are flattened 7x7 patches over 3 channels (3 * 7 * 7 values),
# laid out on a 56 x 56 token map.
k_attention_1 = Kernel_3x3_Convolutional_Attention_t2t(
    dim=3 * 7 * 7,
    num_tokens=56 * 56,
    token_map_height=56,
    token_map_width=56,
    head_separate=False,
    in_dim=token_dim,
    num_heads=1,
)

# Soft splits: one 7x7 / stride-4 unfold followed by two 3x3 / stride-2 unfolds.
soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

In [3]:
# Smoke test: a random batch of ten 3-channel 224x224 inputs goes through the
# first soft split and the attention layer, printing the shape at each step.
x = torch.randn(size=(10, 3, 224, 224))
x = soft_split0(x)
print(x.size())
# Move the token axis in front of the feature axis, as the attention layer expects.
x = torch.transpose(x, 1, 2)
print(x.size())
x = k_attention_1(x)
print(x.size())

torch.Size([10, 147, 3136])
torch.Size([10, 3136, 147])
torch.Size([10, 3136, 64])


In [13]:
import torch.nn.functional as F

# A 4x4 block of ones makes the padded zeros easy to spot.
data = torch.ones((4, 4))
# pad(left, right, top, bottom): append a single zero column on the right.
new_data = F.pad(data, (0, 1, 0, 0), mode="constant", value=0)
print(new_data)

tensor([[1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.]])


In [14]:
# First row of the padded matrix: the trailing entry is the padded zero.
print(new_data[0, :])

tensor([1., 1., 1., 1., 0.])


In [4]:
# Slicing with an end of -1 keeps everything except the last element.
a = torch.tensor([1, 2, 3, 4])
print(a[0:-1])

tensor([1, 2, 3])
