In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [8]:
# reshaping functions

def to_mlp(x):
    """Flatten a ``[B, C, H, W]`` feature map into per-pixel rows of shape ``[B*H*W, C]``.

    Row ``b*H*W + h*W + w`` holds the ``C`` channel values of pixel ``(h, w)`` of sample ``b``,
    i.e. one token per pixel, which is the layout a per-token ``nn.Linear``/MLP expects.

    The channel axis is moved last *before* reshaping: reshaping the NCHW tensor directly
    would slice the flat memory into rows that mix values from different channels and pixels.
    """
    B, C, H, W = x.size()
    return x.permute(0, 2, 3, 1).reshape(B * H * W, C)


def to_chw(x, dims):
    """Inverse of ``to_mlp``: turn ``[B*H*W, C]`` per-pixel rows back into ``[B, C, H, W]``.

    Args:
        x: tensor with ``B*H*W`` rows of ``C`` features, in the row order produced by ``to_mlp``.
        dims: the target ``(B, C, H, W)`` sizes, e.g. the original tensor's ``.size()``.

    Returns:
        A contiguous ``[B, C, H, W]`` tensor such that ``to_chw(to_mlp(t), t.size())`` equals ``t``.
    """
    B, C, H, W = dims
    # Rows are ordered (b, h, w) with channels last, so restore NHWC first, then move C back.
    # .contiguous() keeps the result usable with .view(), like the previous reshape-only output.
    return x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()

In [15]:
# Example [B, C, H, W] feature map, reused by the SelfAttention smoke test below.
# Seed first so the random input is reproducible on Restart & Run All.
# NOTE: 64x64 maps give H*W = 4096 tokens, so self-attention on this input builds
# 32 x 4096 x 4096 float32 score tensors (~2 GiB each).
torch.manual_seed(0)
x = torch.randn(size=[32, 256, 64, 64])

x.transpose(1, 2).shape

torch.Size([32, 64, 256, 64])

In [3]:
# overlap patch embedding

class OverlapPatchEmbedding(nn.Module):
    """Overlapping patch embedding followed by a learned x4 upsampling.

    Input layout is ``[B, in_channels, H, W]``; output is ``[B, embbed_dim, H', W']``.

    * ``encode``: 7x7 conv, stride 4, padding 3. Neighbouring windows overlap, and the
      spatial size becomes ``ceil(H / 4) x ceil(W / 4)``.
    * ``activation``: GELU between the two convolutions.
    * ``decode``: 6x6 transposed conv, stride 4, padding 1. This maps ``n -> (n - 1) * 4 - 2 + 6 = 4 * n``,
      so ``H' = 4 * ceil(H / 4)``, which equals ``H`` whenever ``H`` is divisible by 4 (same for ``W``).

    Args:
        in_channels: number of input channels.
        embbed_dim: number of embedding channels produced by both convolutions.
    """

    def __init__(self, in_channels, embbed_dim):
        super().__init__()

        self.c_in, self.c_out = in_channels, embbed_dim

        # Strided, overlapping patch projection (downsamples by 4).
        self.encode = nn.Conv2d(
            self.c_in,
            self.c_out,
            kernel_size=7,
            stride=4,
            padding=3,
        )

        # Added later by the author in the hope of giving the block more capacity.
        self.activation = nn.GELU()

        # Learned upsampling that undoes the stride-4 reduction.
        self.decode = nn.ConvTranspose2d(
            self.c_out,
            self.c_out,
            kernel_size=6,
            stride=4,
            padding=1,
        )

    def forward(self, x):
        """Encode, activate, then decode ``x``."""
        hidden = self.encode(x)
        hidden = self.activation(hidden)
        return self.decode(hidden)

In [18]:
# single-head self-attention over spatial positions (the attention sub-layer of a transformer block)

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention over the pixels of a feature map.

    Every pixel of a ``[B, C, H, W]`` input is one token whose feature vector is its ``C``
    channel values, so the layer attends over ``H * W`` tokens and returns a tensor of the
    same ``[B, C, H, W]`` shape.

    Args:
        embedd_dim: number of input channels ``C``; also the width of the query/key/value
            projections.

    Note:
        The score matrix is ``[B, H*W, H*W]``, so memory grows quadratically with the number
        of pixels (a 64x64 map already gives 4096 x 4096 scores per sample).
    """

    def __init__(self, embedd_dim):
        super(SelfAttention, self).__init__()

        self.query_weight = nn.Linear(in_features=embedd_dim, out_features=embedd_dim)
        self.key_weight = nn.Linear(in_features=embedd_dim, out_features=embedd_dim)
        self.value_weight = nn.Linear(in_features=embedd_dim, out_features=embedd_dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, C, H, W]``; returns ``[B, C, H, W]``."""
        B, C, H, W = x.size()

        # [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]: one row per pixel holding its C channels.
        # A plain reshape(B, H*W, C) would instead cut the flat NCHW memory into rows that mix
        # different channels and pixels.
        tokens = x.flatten(2).transpose(1, 2)

        # query, key and value tensors, each [B, H*W, C]
        q = self.query_weight(tokens)
        k = self.key_weight(tokens)
        v = self.value_weight(tokens)

        # Scale q (H*W x C) instead of the H*W x H*W scores: same result as scores / sqrt(C),
        # but avoids allocating one extra quadratic-size temporary.
        scores = torch.bmm(q / math.sqrt(C), k.transpose(1, 2))
        weights = F.softmax(scores, dim=-1)

        attention = torch.bmm(weights, v)  # [B, H*W, C]

        # Undo the token layout: [B, H*W, C] -> [B, C, H*W] -> [B, C, H, W].
        return attention.transpose(1, 2).reshape(B, C, H, W)





In [19]:
test = SelfAttention(256)

# Shape-only smoke test: no gradients are needed, so don't record an autograd graph
# for the (B x H*W x H*W) attention scores.
with torch.no_grad():
    attn_shape = test(x).shape

# Self-attention must hand back a tensor in the same [B, C, H, W] layout it received.
assert attn_shape == x.shape, attn_shape
attn_shape

torch.Size([32, 256, 64, 64])