In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [3]:
# reshaping functions

def to_mlp(x):
    """Flatten a (B, C, H, W) feature map into (B*H*W, C) rows for nn.Linear.

    Each output row is the C-channel feature vector of one pixel, ordered by
    batch, then row, then column. Channels are moved to the last axis first:
    reshaping (B, C, H, W) straight to (B*H*W, C) would put values from
    different channels and different pixels into the same row.
    """
    B, C, H, W = x.size()
    return x.permute(0, 2, 3, 1).reshape(B * H * W, C)


def to_chw(x, dims):
    """Inverse of `to_mlp`: turn (B*H*W, C) rows back into a (B, C, H, W) map.

    Args:
        x: tensor holding B*H*W rows of C features, in the row order `to_mlp` emits.
        dims: target sizes (B, C, H, W) as a list, tuple or torch.Size.

    Returns:
        A channels-first (B, C, H, W) tensor (may be non-contiguous).
    """
    B, C, H, W = dims
    # rows are (b, h, w)-ordered with channels last -> restore channels-first
    return x.reshape(B, H, W, C).permute(0, 3, 1, 2)

In [4]:
# Quick check: swapping axes 1 and 2 exchanges the channel and height sizes.
x = torch.randn(32, 256, 64, 64)
x.transpose(2, 1).shape

torch.Size([32, 64, 256, 64])

In [5]:
# overlap patch embedding

class OverlapPatchEmbedding(nn.Module):
    """Overlapping patch embedding that projects channels and restores resolution.

    Input is expected as [B, C, H, W] with C == in_channels.

    - `encode`: 7x7 Conv2d, stride 4, padding 3 (in_channels -> embbed_dim),
      giving spatial size ceil(H / 4).
    - `activation`: GELU.
    - `decode`: 6x6 ConvTranspose2d, stride 4, padding 1 (embbed_dim -> embbed_dim),
      multiplying the spatial size by 4.

    The output spatial size is therefore 4 * ceil(H / 4), which equals H only
    when H (and likewise W) is a multiple of 4.

    Args:
        in_channels: number of input channels.
        embbed_dim: number of output (embedding) channels.
    """

    def __init__(self, in_channels, embbed_dim):
        super().__init__()

        self.c_in, self.c_out = in_channels, embbed_dim

        self.encode = nn.Conv2d(self.c_in, self.c_out, kernel_size=7, stride=4, padding=3)

        # Added afterwards; the hope is that it gives the block more learning capacity.
        self.activation = nn.GELU()

        self.decode = nn.ConvTranspose2d(self.c_out, self.c_out, kernel_size=6, stride=4, padding=1)

    def forward(self, x):
        embedded = self.activation(self.encode(x))
        return self.decode(embedded)

In [6]:
# self-attention - transformer block

class SelfAttention(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Every pixel of a (B, C, H, W) input is one token whose feature vector is its
    C channels. Attention is computed among the H*W tokens of each sample and
    the result is returned in the input's (B, C, H, W) layout.

    Args:
        embedd_dim: channel count C of the input (also the q/k/v width).
    """

    def __init__(self, embedd_dim):
        super(SelfAttention, self).__init__()

        self.query_weight = nn.Linear(in_features=embedd_dim, out_features=embedd_dim)
        self.key_weight = nn.Linear(in_features=embedd_dim, out_features=embedd_dim)
        self.value_weight = nn.Linear(in_features=embedd_dim, out_features=embedd_dim)

    def forward(self, x):
        B, C, H, W = x.size()

        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C). Channels must reach the last
        # axis before the Linear layers; a plain reshape to (B, H*W, C) would mix
        # channel and spatial values inside each "token".
        tokens = x.flatten(2).transpose(1, 2)

        # query, key and value tensors, each (B, H*W, C)
        q = self.query_weight(tokens)
        k = self.key_weight(tokens)
        v = self.value_weight(tokens)

        # (B, H*W, H*W) attention weights, scaled by sqrt(d_k) with d_k = C
        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(C)
        weights = F.softmax(scores, dim=-1)

        attention = torch.bmm(weights, v)  # (B, H*W, C)

        # Undo the token layout: (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W).
        # NOTE(review): no residual connection is added here (MIX_FFN adds its own);
        # confirm that this is intended.
        return attention.transpose(1, 2).reshape(B, C, H, W)





In [7]:
class MIX_FFN(nn.Module):
    """Mix-FFN block with a residual connection; output shape equals input shape.

    Pipeline on a (B, C, H, W) input with C == embedd_dim:
    Linear (C -> 2C) -> 3x3 Conv2d (2C -> 2C) -> GELU -> Linear (2C -> C),
    then the input is added back. The Linear layers are applied to the
    (B*H*W, C)-shaped tensors produced by the module-level helper `to_mlp`, and
    `to_chw` converts back to (B, C, H, W) for the convolution and the output.

    Args:
        embedd_dim: channel count C of the input and output.
    """

    def __init__(self, embedd_dim):
        super().__init__()
        hidden_dim = 2 * embedd_dim

        self.mlp1 = nn.Linear(embedd_dim, hidden_dim)
        self.conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, stride=1)
        self.activation = nn.GELU()
        self.mlp2 = nn.Linear(hidden_dim, embedd_dim)

    def forward(self, x):
        batch, channels, height, width = x.size()
        residual = x

        # expand channels, then mix spatially with the 3x3 conv and apply GELU
        hidden = self.mlp1(to_mlp(x))
        hidden = to_chw(hidden, [batch, 2 * channels, height, width])
        hidden = self.activation(self.conv(hidden))

        # project back to the input width and restore (B, C, H, W)
        out = self.mlp2(to_mlp(hidden))
        out = to_chw(out, [batch, channels, height, width])

        return out + residual

In [8]:
# Smoke test: MIX_FFN is residual, so it must preserve the (B, C, H, W) shape.
ffn = MIX_FFN(32)
x = torch.randn(16, 32, 22, 16)
with torch.no_grad():  # shape check only; no autograd graph needed
    ffn_out = ffn(x)
assert ffn_out.shape == x.shape, "MIX_FFN must preserve the input shape"
ffn_out.shape

torch.Size([16, 32, 22, 16])

In [9]:
class OverlapPatchMerging(nn.Module):
    """Halve the spatial resolution while doubling the channel count.

    A single 7x7 Conv2d with stride 2 and padding 3 maps
    (B, embedd_dim, H, W) -> (B, 2 * embedd_dim, ceil(H / 2), ceil(W / 2)).

    Args:
        embedd_dim: number of input channels.
    """

    def __init__(self, embedd_dim):
        super().__init__()
        # Divides the spatial dimensions by 2.
        # In the original article the spatial dimensions are divided by 4, but
        # because of the format of the microCT data we only divide by 2.
        self.conv = nn.Conv2d(
            in_channels=embedd_dim,
            out_channels=2 * embedd_dim,
            kernel_size=7,
            stride=2,
            padding=3,
        )

    def forward(self, x):
        merged = self.conv(x)
        return merged
    

In [10]:
# Smoke test: stride-2 merging doubles channels and halves (rounding up) H and W.
opm = OverlapPatchMerging(32)
with torch.no_grad():  # shape check only; no autograd graph needed
    opm_out = opm(x)
expected_shape = (x.shape[0], 2 * x.shape[1], math.ceil(x.shape[2] / 2), math.ceil(x.shape[3] / 2))
assert opm_out.shape == expected_shape, f"unexpected merged shape {tuple(opm_out.shape)}"
opm_out.shape

torch.Size([16, 64, 11, 8])

In [11]:
class TransformerChain(nn.Module):
    """Four (SelfAttention -> MIX_FFN) stages applied in sequence.

    Every stage keeps the (B, C, H, W) shape, with C == embedd_dim, so the chain
    does as well.

    Args:
        embedd_dim: channel count shared by every stage.
    """

    def __init__(self, embedd_dim):
        super().__init__()

        # Stage i is att{i} followed by ffn{i}; registered in the order
        # att1, ffn1, att2, ffn2, ... so parameter names stay stable.
        self.att1 = SelfAttention(embedd_dim)
        self.ffn1 = MIX_FFN(embedd_dim)
        self.att2 = SelfAttention(embedd_dim)
        self.ffn2 = MIX_FFN(embedd_dim)
        self.att3 = SelfAttention(embedd_dim)
        self.ffn3 = MIX_FFN(embedd_dim)
        self.att4 = SelfAttention(embedd_dim)
        self.ffn4 = MIX_FFN(embedd_dim)

    def forward(self, x):
        stages = (
            (self.att1, self.ffn1),
            (self.att2, self.ffn2),
            (self.att3, self.ffn3),
            (self.att4, self.ffn4),
        )
        out = x
        for attention, feed_forward in stages:
            out = feed_forward(attention(out))
        return out

In [12]:
# Smoke test: the transformer chain must preserve the (B, C, H, W) shape.
t = TransformerChain(2)
x = torch.randn(16, 2, 10, 5)
with torch.no_grad():  # shape check only; no autograd graph needed
    t_out = t(x)
assert t_out.shape == x.shape, "TransformerChain must preserve the input shape"
t_out.shape

torch.Size([16, 2, 10, 5])