In [1]:
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import timm

In [2]:
img = torch.Tensor(np.random.random((4,3,384,384)))
img.shape

torch.Size([4, 3, 384, 384])

In [3]:
class Read_ignore(nn.Module):
    def __init__(self, start_index=1):
        super(Read_ignore, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index:]


class Read_add(nn.Module):
    def __init__(self, start_index=1):
        super(Read_add, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index :] + readout.unsqueeze(1)


class Read_projection(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(Read_projection, self).__init__()
        self.start_index = start_index
        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
        features = torch.cat((x[:, self.start_index :], readout), -1)
        return self.project(features)

In [33]:
class FocusOnDepth(nn.Module):
    def __init__(self, 
                 image_size = (3, 384, 384), 
                 patch_size = (16, 16), 
                 emb_dim = 1024, 
                 read = 'projection',
                 num_layers_encoder = 24,
                 hooks = [0,1,2,3],
                 nhead = 16,
                 transformer_dropout = 0):
        """
        Focus on Depth - Large
        image_size : (c, h, w)
        patch_size : (h, w)
        emb_dim <=> D (in the paper)
        read : {"ignore", "add", "projection"}
        """
        super().__init__()
        
        #Splitting img into patches
        channels, image_height, image_width = image_size
        patch_height, patch_width = patch_size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_height // patch_height) * \
                      (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, emb_dim),
        )
        #Embedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
        
        #Transformer
        self.activation = {}
        self.hooks = hooks
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead, dropout=transformer_dropout, dim_feedforward=emb_dim*4)
        self.transformer_encoders = nn.TransformerEncoder(encoder_layer, num_layers=num_layers_encoder)
        #Register hooks
        self.get_layers_from_hooks(self.hooks)

        #Concat after read
        self.concat = nn.Sequential(
            Rearrange('b (h w) c -> b c h w', c=emb_dim,  h=(image_height // patch_height), w=(image_width // patch_width))
        )

        #Read
        self.read = Read_ignore()
        if read == 'add':
            self.read = Read_add()
        elif read == 'projection':
            self.read = Read_projection(emb_dim)
    
    def get_layers_from_hooks(self, hooks):
        def get_activation(name):
            def hook(model, input, output):
                self.activation[name] = output.detach()
            return hook
        for h in hooks:
            self.transformer_encoders.layers[h].register_forward_hook(get_activation('t'+str(h)))

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        print(x.shape)
        t = self.transformer_encoders(x)
        print(t.shape)
        act1 = self.activation['t1']
        act1 = self.read(act1)
        act1 = self.concat(act1)
        print(act1.shape)



In [34]:
model = FocusOnDepth((3,384,384), (16, 16), 128, "projection", 5, [1], 4)
model(img)
#model

torch.Size([4, 577, 128])
torch.Size([4, 577, 128])
torch.Size([4, 128, 24, 24])


In [18]:
model.activation

{'t1': tensor([[[ 0.5244,  0.6776,  0.9152,  ..., -1.4480,  0.3267, -2.3474],
          [ 1.3049, -1.7115, -0.2993,  ..., -1.0103, -0.7556,  0.3068],
          [-0.5465,  0.7846, -0.2911,  ...,  0.8926,  1.2563,  1.3280],
          ...,
          [ 1.3120, -1.8247, -0.4216,  ...,  1.1285,  0.2768,  1.0452],
          [ 0.1289,  1.7938, -0.1779,  ..., -0.4079,  0.5089,  0.8926],
          [ 0.3298,  1.3368, -0.1826,  ...,  1.6811, -1.1189,  0.5914]],
 
         [[ 0.5244,  0.6776,  0.9152,  ..., -1.4480,  0.3267, -2.3474],
          [ 1.4871, -1.9232, -0.0783,  ..., -0.9625, -0.7430,  0.1225],
          [-0.0629,  0.7538, -0.5719,  ...,  0.8150,  1.4043,  1.3633],
          ...,
          [ 1.3689, -1.7542, -0.3998,  ...,  1.2487,  0.3106,  1.0138],
          [ 0.1369,  1.4279, -0.2067,  ..., -0.1465,  0.4153,  0.9271],
          [ 0.2446,  1.2247, -0.1281,  ...,  1.7195, -1.1244,  0.5852]],
 
         [[ 0.5244,  0.6776,  0.9152,  ..., -1.4480,  0.3267, -2.3474],
          [ 1.1764, -1

In [6]:
encoder_layer = nn.TransformerEncoderLayer(d_model=1024, nhead=16, dropout=0, dim_feedforward=1024*4)
transformer_encoders = nn.TransformerEncoder(encoder_layer, num_layers=24)

TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
  )
  (linear1): Linear(in_features=1024, out_features=4096, bias=True)
  (dropout): Dropout(p=0, inplace=False)
  (linear2): Linear(in_features=4096, out_features=1024, bias=True)
  (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0, inplace=False)
  (dropout2): Dropout(p=0, inplace=False)
)