In [1]:
import collections
from itertools import repeat

import torch.nn as nn
import torch
import math

from einops import rearrange

In [2]:
class MHLinear(nn.Module): # Multihead at output
    def __init__(self, in_features, out_features, head, bias=True):
        super(MHLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.head = head
        self.weight = nn.Parameter(torch.Tensor(out_features,in_features,head))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features,head))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        weight = rearrange(self.weight, 'q e h -> (h q) e')
        bias = rearrange(self.bias, 'q h -> (h q)')
        return nn.functional.linear(input, weight, bias)
    
class MHTLinear(nn.Module): # Multihead at input
    def __init__(self, in_features, out_features, head, bias=True):
        super(MHTLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.head = head
        self.weight = nn.Parameter(torch.Tensor(out_features,in_features,head))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        weight = rearrange(self.weight, 'e v h -> e (h v)')
        return nn.functional.linear(input, weight, self.bias)

## QK

In [125]:
class QK(nn.Module):
    def __init__(self, emb, qk, head, qkv_bias=True):
        super().__init__()
        self.head = head
        self.Q = MHLinear(emb, qk, head, bias=qkv_bias)
        self.K = MHLinear(emb, qk, head, bias=qkv_bias)
        self.scale = head ** -0.5
        
    def forward(self, x):
        B, N, C = x.shape
        q = self.Q(x)
        k = self.K(x)
        print(q.shape, k.shape)
        x = torch.cat((q,k),-1)
        qk = x.reshape(B, N, 2, self.head, -1).permute(2, 0, 3, 1, 4)
        q, k = qk.unbind(0)
        
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        return attn

In [126]:
dummy_QK_input = torch.zeros((128, 198, 768))
dummy_QK_input = dummy_QK_input.cuda()

In [127]:
dummy_QK_input.dtype

torch.float32

In [128]:
QK_model = QK(emb=768, qk=64, head=12)
QK_model = QK_model.cuda()

In [132]:
QK_model.Q.weight.shape

torch.Size([64, 768, 12])

In [129]:
attn = QK_model(dummy_QK_input)

torch.Size([128, 198, 768]) torch.Size([128, 198, 768])


In [83]:
attn.shape

torch.Size([128, 12, 198, 198])

## V

In [65]:
# class V(nn.Module):
#     def __init__(self, emb, v, head, qkv_bias=True):
#         super().__init__()
#         self.V = MHLinear(emb, v, head, bias=qkv_bias)
        
#     def forward(self, x):
#         v = self.V(x)
#         return v

In [82]:
# dummy_V_input = torch.zeros((128, 198, 768))
# dummy_V_input = dummy_V_input.cuda()

In [83]:
# dummy_QK_input.dtype

torch.float32

In [84]:
# V_model = V(emb=768, v=64, head=12)
# V_model = V_model.cuda()

In [85]:
# v = V_model(dummy_V_input)

In [86]:
# v.shape

torch.Size([128, 198, 768])

## V + PROJ

In [99]:
class V_AND_PROJ(nn.Module):
    def __init__(self, emb, v, head, qkv_bias=True):
        super().__init__()
        self.V = MHLinear(emb, v, head, bias=qkv_bias)
        self.proj = MHTLinear(v, emb, head)
        self.head = head
        
    def forward(self, x, attn):
        B, N, C = x.shape
#         print(x.shape)
        v = self.V(x)
        v = v.reshape(B, N, 1, self.head, -1).permute(2, 0, 3, 1, 4)
        v = v[0]
        print(v.shape)
        print(attn.shape)
        x = attn @ v
        print(x.shape)
#         print(x.transpose(1, 2).shape)
#         print(B, N, C)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

In [100]:
dummy_V_input = torch.zeros((256, 198, 16))
dummy_V_input = dummy_V_input.cuda()
dummy_attn = torch.zeros((256, 12, 198, 198))
dummy_attn = dummy_attn.cuda()

In [101]:
# dummy_V_input = torch.zeros((576, 198, 768))
# dummy_V_input = dummy_V_input.cuda()
# dummy_attn = torch.zeros((576, 198, 198))
# dummy_attn = dummy_attn.cuda()

In [102]:
# V_AND_PROJ_MODEL = V_AND_PROJ(emb=768, v=64, head=12)
# V_AND_PROJ_MODEL = V_AND_PROJ_MODEL.cuda()

In [103]:
V_AND_PROJ_MODEL = V_AND_PROJ(emb=16, v=64, head=12)
V_AND_PROJ_MODEL = V_AND_PROJ_MODEL.cuda()

In [104]:
V_AND_PROJ_MODEL.V.weight.shape

torch.Size([64, 16, 12])

In [105]:
x = V_AND_PROJ_MODEL(dummy_V_input, dummy_attn)

torch.Size([256, 12, 198, 64])
torch.Size([256, 12, 198, 198])
torch.Size([256, 12, 198, 64])


RuntimeError: shape '[256, 198, 16]' is invalid for input of size 38928384

In [102]:
x.shape

torch.Size([128, 198, 768])

## MLP

In [100]:
# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

In [105]:
MLP_MODEL = Mlp(768, 3072)
MLP_MODEL = MLP_MODEL.cuda()

In [108]:
MLP_MODEL

Mlp(
  (fc1): Linear(in_features=768, out_features=3072, bias=True)
  (act): GELU()
  (drop1): Dropout(p=0.0, inplace=False)
  (norm): Identity()
  (fc2): Linear(in_features=3072, out_features=768, bias=True)
  (drop2): Dropout(p=0.0, inplace=False)
)

In [106]:
dummy_mlp_input = torch.zeros((128, 198, 768))
dummy_mlp_input = dummy_mlp_input.cuda()

In [107]:
MLP_MODEL(dummy_mlp_input)

tensor([[[ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         ...,
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202]],

        [[ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         ...,
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202]],

        [[ 0.0275, -0.0090,  0.0120,  ...,  0.0070,  0.0174, -0.0202],
         [ 0.0275, -0.0090,  0.0120,  ...,  0

## Timm

In [120]:
class TEST(nn.Module):
    def __init__(self, emb, v, head, qkv_bias=True):
        super().__init__()
#         self.qkv = nn.Linear(emb, emb * 3, bias=qkv_bias)
        self.qkv = nn.Linear(emb*head, emb*head*3, bias=qkv_bias)
        self.V = MHLinear(emb, v, head, bias=qkv_bias)
        self.proj = MHTLinear(v, emb, head)
        self.head = head
        
    def forward(self, x, attn):
        B, N, C = x.shape
        qkv = self.qkv(x)
        print(qkv.shape)
        qkv = qkv.reshape(B, N, 3, self.head, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        print(attn.shape)
        print(v.shape)
        x = attn @ v
# #         print(x.shape)
#         v = self.V(x)
#         v = v.reshape(B, N, 1, self.head, -1).permute(2, 0, 3, 1, 4)
#         v = v[0]
#         print(v.shape)
#         print(attn.shape)
#         x = attn @ v
#         print(x.shape)
# #         print(x.transpose(1, 2).shape)
# #         print(B, N, C)
#         x = x.transpose(1, 2).reshape(B, N, C)
#         x = self.proj(x)
#         return x

In [121]:
dummy_V_input = torch.zeros((256, 198, 16))
dummy_V_input = dummy_V_input.cuda()
dummy_attn = torch.zeros((256, 12, 198, 198))
dummy_attn = dummy_attn.cuda()

In [122]:
TEST_MODEL = TEST(emb=16, v=64, head=12)
TEST_MODEL = TEST_MODEL.cuda()

In [123]:
TEST_MODEL(dummy_V_input, dummy_attn)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (50688x16 and 192x576)

In [119]:
256 * 198 * 3 * 12

1824768