In [1]:
import collections
from itertools import repeat

import torch.nn as nn
import torch
import math

from einops import rearrange

In [2]:
class QKV_s(nn.Module):
    def __init__(self, emb, qk, v, qkv_bias=True):
        super().__init__()

        self.Q = nn.Linear(emb, qk, bias=qkv_bias)
        self.K = nn.Linear(emb, qk, bias=qkv_bias)
        self.V = nn.Linear(emb, v, bias=qkv_bias)
        self.token_mask = nn.Parameter(torch.ones(198, 1))

    def forward(self, x):
        q = self.Q(x)*self.token_mask
        k = self.K(x)*self.token_mask
        v = self.V(x)*self.token_mask
        return q,k,v

class ATT(nn.Module):
    def __init__(self, attn_drop=0., scale=None):
        super().__init__()
        self.attn_drop = nn.Dropout(attn_drop)
        self.scale = scale

    def forward(self, q, k):
        attn_r = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn_r.softmax(dim=-1)
        attn = self.attn_drop(attn)
        return attn

class PROJ(nn.Module):
    def __init__(self, v, dim, proj_drop=0.):
        super().__init__()
        self.v = v
        self.proj = nn.Linear(v, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, attn, v):
        x = (attn @ v).transpose(1, 2).reshape(-1, 198, self.v)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, qk, v, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        
        head_dim = dim // self.num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = 64 ** -0.5

        self.qkv = QKV_s(dim, qk, v)
        self.att = ATT(attn_drop,self.scale)
        self.proj = PROJ(v, dim, proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        q, k, v = self.qkv(x)
        qk_dim = q.shape[2]
        v_dim = v.shape[2]
        q = q.reshape(B, N, self.num_heads, qk_dim // self.num_heads).permute(0, 2, 1, 3)
        k = k.reshape(B, N, self.num_heads, qk_dim // self.num_heads).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, v_dim // self.num_heads).permute(0, 2, 1, 3)

        attn = self.att(q,k)
        x = self.proj(attn,v)
        
        return x

## QK

In [3]:
class QK(nn.Module):
    def __init__(self, emb, qk, head, qkv_bias=True):
        super().__init__()
        self.head = head
        self.qk_dim = qk * head
        self.Q = nn.Linear(emb, self.qk_dim, bias=qkv_bias)
        self.K = nn.Linear(emb, self.qk_dim, bias=qkv_bias)
        self.scale = head ** -0.5
        
    def forward(self, x):
        B, N, C = x.shape
        q = self.Q(x)
        k = self.K(x)
        print(q.shape, k.shape)
        qk_dim = q.shape[2]
        q = q.reshape(B, N, self.head, qk_dim // self.head).permute(0, 2, 1, 3)
        k = k.reshape(B, N, self.head, qk_dim // self.head).permute(0, 2, 1, 3)
        attn_r = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn_r.softmax(dim=-1)
        return attn

In [4]:
dummy_QK_input = torch.zeros((128, 198, 768))
dummy_QK_input = dummy_QK_input.cuda()

In [5]:
dummy_QK_input.dtype

torch.float32

In [6]:
QK_model = QK(emb=768, qk=64, head=12)
QK_model = QK_model.cuda()

In [8]:
attn = QK_model(dummy_QK_input)

torch.Size([128, 198, 768]) torch.Size([128, 198, 768])


In [9]:
attn.shape

torch.Size([128, 12, 198, 198])

## V + PROJ

In [23]:
class V_AND_PROJ(nn.Module):
    def __init__(self, emb, v, head, qkv_bias=True):
        super().__init__()
        self.v_dim = v * head
        self.V = nn.Linear(emb, self.v_dim, bias=qkv_bias)
        self.proj = nn.Linear(self.v_dim, emb)
        self.head = head
        
    def forward(self, x, attn):
        B, N, C = x.shape
        v = self.V(x)
        v = v.reshape(B, N, self.head, self.v_dim // self.head).permute(0, 2, 1, 3)
        x = (attn @ v).transpose(1, 2).reshape(-1, 198, self.v_dim)
        x = self.proj(x)
        return x

In [24]:
dummy_V_input = torch.zeros((256, 198, 16))
dummy_V_input = dummy_V_input.cuda()
dummy_attn = torch.zeros((256, 12, 198, 198))
dummy_attn = dummy_attn.cuda()

In [25]:
V_AND_PROJ_MODEL = V_AND_PROJ(emb=16, v=64, head=12)
V_AND_PROJ_MODEL = V_AND_PROJ_MODEL.cuda()

In [26]:
x = V_AND_PROJ_MODEL(dummy_V_input, dummy_attn)

In [27]:
x.shape

torch.Size([256, 198, 16])

## MLP

In [100]:
# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

In [105]:
MLP_MODEL = Mlp(768, 3072)
MLP_MODEL = MLP_MODEL.cuda()

In [108]:
MLP_MODEL

Mlp(
  (fc1): Linear(in_features=768, out_features=3072, bias=True)
  (act): GELU()
  (drop1): Dropout(p=0.0, inplace=False)
  (norm): Identity()
  (fc2): Linear(in_features=3072, out_features=768, bias=True)
  (drop2): Dropout(p=0.0, inplace=False)
)

In [106]:
dummy_mlp_input = torch.zeros((128, 198, 768))
dummy_mlp_input = dummy_mlp_input.cuda()

In [107]:
MLP_MODEL(dummy_mlp_input)

tensor([[[ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         ...,
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202]],

        [[ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         ...,
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202]],

        [[ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0

## Timm

In [120]:
class TEST(nn.Module):
    def __init__(self, emb, v, head, qkv_bias=True):
        super().__init__()
#         self.qkv = nn.Linear(emb, emb * 3, bias=qkv_bias)
        self.qkv = nn.Linear(emb*head, emb*head*3, bias=qkv_bias)
        self.V = MHLinear(emb, v, head, bias=qkv_bias)
        self.proj = MHTLinear(v, emb, head)
        self.head = head
        
    def forward(self, x, attn):
        B, N, C = x.shape
        qkv = self.qkv(x)
        print(qkv.shape)
        qkv = qkv.reshape(B, N, 3, self.head, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        print(attn.shape)
        print(v.shape)
        x = attn @ v
# #         print(x.shape)
#         v = self.V(x)
#         v = v.reshape(B, N, 1, self.head, -1).permute(2, 0, 3, 1, 4)
#         v = v[0]
#         print(v.shape)
#         print(attn.shape)
#         x = attn @ v
#         print(x.shape)
# #         print(x.transpose(1, 2).shape)
# #         print(B, N, C)
#         x = x.transpose(1, 2).reshape(B, N, C)
#         x = self.proj(x)
#         return x

In [121]:
dummy_V_input = torch.zeros((256, 198, 16))
dummy_V_input = dummy_V_input.cuda()
dummy_attn = torch.zeros((256, 12, 198, 198))
dummy_attn = dummy_attn.cuda()

In [122]:
TEST_MODEL = TEST(emb=16, v=64, head=12)
TEST_MODEL = TEST_MODEL.cuda()

In [123]:
TEST_MODEL(dummy_V_input, dummy_attn)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (50688x16 and 192x576)

In [119]:
256 * 198 * 3 * 12

1824768