In [1]:
import torch
from torch import nn
from typing import Callable, List, Optional, Tuple
import math
from timm.models.layers import trunc_normal_
import numpy as np

In [2]:
class VerboseNNModule(nn.Module):
    
    @staticmethod
    def get_readable_tensor_representation(name: str, tensor: torch.Tensor):
        st = (
            "(" + name + "): " + "tensor(" + str(tuple(tensor[1].shape)) + ", requires_grad=" + str(tensor[1].requires_grad) + ")\n"
            )
        return st
    
    def extra_repr(self) -> str:
        named_modules = set()
        for p in self.named_modules():
            named_modules.update(p[0])
        named_modules = list(named_modules)

        string_repr = ""
        for p in self.named_parameters():
            name = p[0].split(".")[0]
            if name in named_modules:
                string_repr += self.get_readable_tensor_representation(name, p)
        
        for p in self.named_buffers():
            name = p[0].split(".")[0]
            string_repr += self.get_readable_tensor_representation(name, p)
        
        return string_repr

class MyLayer(VerboseNNModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 2)
        self.register_buffer("running_mean", torch.zeros(5))

# Instantiate and print the model
model = MyLayer()
print(model)

MyLayer(
  (running_mean): tensor((5,), requires_grad=False)
  
  (linear): Linear(in_features=1, out_features=2, bias=True)
)


In [3]:
def build_causal_attention_mask(context_length):
    mask = torch.empty(context_length, context_length, requires_grad=False)
    mask.fill_(float("-inf"))
    mask.triu_(1)
    return mask

build_causal_attention_mask(5)

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [4]:
class TextPreprocessor(VerboseNNModule):
    def __init__(self, vocab_size: int, context_length: int, embed_dim: int, causual_mask: bool, 
                 supply_seq_len_to_head: bool = True, init_param_style: str = "openclip"):
        """
        `vocab_size`: Number of tokens in your vocabulary.                 the number of words in your text, so we can map nn.Embedding
	    `context_length`: Maximum number of tokens per input sequence.     usually: 77
	    `embed_dim`: Dimensionality of each token embedding.               usually: 768
        """

        super().__init__()
        
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.causual_mask = causual_mask
        self.embed_dim = embed_dim
        self.supply_seq_len_to_head = supply_seq_len_to_head
        
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(
            torch.empty(1, context_length, embed_dim)
        )
        if causual_mask:
            mask = build_causal_attention_mask(context_length)
            self.register_buffer("mask", mask) # register the mask as a buffer so it can be moved to the right device
        
        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style = "openclip"):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

        # I did'nt use init_param_style as I was too lazy to implment [CLS]
    
    def forward(self, text):
        token_text = self.token_embedding(text)
        token_text = token_text + self.pos_embed
        
        return_dict = {
            "trunk": {
                "tokens": token_text
            },
            "head": {},
        }

        if self.supply_seq_len_to_head:
            text_lengths = text.argmax(dim = -1)
            #  hacky and non-standard way of getting the sequence length.
            return_dict["head"] = {
                "seq_len": text_lengths,
            }
        if self.causual_mask:
            return_dict["trunk"].update({"attn_mask": self.mask})
        
        return return_dict
    
vocab_size = 100
context_length = 77
embed_dim = 768

# Sample input: batch of 1, padded or truncated to 77 tokens
text = torch.randint(0, vocab_size, (2, context_length))  # shape [1, 77]

text_processor = TextPreprocessor(
    vocab_size=vocab_size,
    context_length=context_length,
    embed_dim=embed_dim,
    causual_mask=True,
    supply_seq_len_to_head=True
)

out = text_processor(text)
print(out["trunk"]["tokens"].shape)  # âžœ [2, 77, 768]

torch.Size([2, 77, 768])


In [5]:
out["trunk"]["attn_mask"].shape

torch.Size([77, 77])

In [6]:
out["head"]["seq_len"]

tensor([22, 72])

In [7]:
print(text_processor)

TextPreprocessor(
  (mask): tensor((77, 77), requires_grad=False)
  
  (token_embedding): Embedding(100, 768)
)


In [8]:
class PatchEmbedGeneric(nn.Module):
    def __init__(self, proj_stem, norm_layer: Optional[Callable] = None):
        super().__init__()

        if len(proj_stem) > 1:
            self.proj = nn.Sequential(*proj_stem)
        else:
            # Special case to be able to load pre-trained models that were
            # trained with a standard stem
            self.proj = proj_stem[0]
        self.norm_layer = norm_layer
    
    def get_patch_layout(self, image_size):
        with torch.no_grad():
            dummy_img = torch.zeros([1,] + image_size)      # 1, C, (T), H, W
            dummy_out = self.proj(dummy_img)
        # print(dummy_out.shape)
        embed_dim = dummy_out.shape[1]                    # `embed_dim`    = C        
        patch_layout = tuple(dummy_out.shape[2:])         # `patch_layout` = (T), H, W       
        num_patches = np.prod(patch_layout)               # `num_patches`  = (T) * H * W       
        return embed_dim, patch_layout, num_patches
    
    def forward(self, x: torch.Tensor):
        x = x.flatten(2)                                  # B, C, (T), H, W -> B, C, (T)*H*W
        x = x.transpose(1, 2)                             # B, C, (T)*H*W   -> B, (T)*H*W, C
        if self.norm_layer is not None:
            x = self.norm_layer(x)
        return x

## Testing
proj_stem = [
    nn.Linear(224, 224),
    nn.ReLU()
]

patch_embed = PatchEmbedGeneric(proj_stem)
x = torch.randn(2, 3, 224, 224)

out = patch_embed(x)
out.shape
 

torch.Size([2, 100352, 3])

In [9]:
x = torch.rand(3, 224, 224)
embed_dim, patch_layout, num_patches = patch_embed.get_patch_layout(list(x.shape))

print(f"Embed dim: {embed_dim}")
print(f"Patch layout: {patch_layout}")
print(f"Number of patches: {num_patches}")

Embed dim: 3
Patch layout: (224, 224)
Number of patches: 50176


In [10]:
def get_sinusoid_encoding_table(n_position, d_hid):
    """Sinusoid position encoding table"""

    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        return [
            position / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ]

    sinusoid_table = np.array(
        [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
    )
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)

pos_encoding = get_sinusoid_encoding_table(n_position=196, d_hid=768)
print(pos_encoding.shape)  # Output: (1, 4, 8)
print(pos_encoding)


torch.Size([1, 196, 768])
tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  8.2843e-01,  ...,  1.0000e+00,
           1.0243e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.2799e-01,  ...,  1.0000e+00,
           2.0486e-04,  1.0000e+00],
         ...,
         [-9.7846e-01, -2.0645e-01, -6.9584e-02,  ...,  9.9980e-01,
           1.9767e-02,  9.9980e-01],
         [-7.0239e-01,  7.1180e-01,  7.8745e-01,  ...,  9.9979e-01,
           1.9870e-02,  9.9980e-01],
         [ 2.1945e-01,  9.7562e-01,  9.5167e-01,  ...,  9.9979e-01,
           1.9972e-02,  9.9980e-01]]])


In [11]:
def interpolate_pos_encoding(npatch_per_image, pos_embed, first_patch_idx: int = 1): 
    # If CLS present first_patch_idx = 1

    assert first_patch_idx == 0 or first_patch_idx == 1, "CLS can be either present or not present"
    # assert 
    N = pos_embed.shape[1] - first_patch_idx             # If CLS is present tokens from the 1: to rest are actual stuff
    
    if npatch_per_image == N:
        return pos_embed

    class_emb = pos_embed[:, :first_patch_idx]
    pos_embed = pos_embed[:, first_patch_idx:]

    return torch.cat((class_emb, pos_embed), dim=1)

pos_embed = torch.rand(1, 196, 768)
interpolate_pos_encoding(npatch_per_image = 196, pos_embed = pos_embed, first_patch_idx=1).shape

torch.Size([1, 196, 768])

In [12]:
def _get_pos_embedding(npatch_per_image, pos_embed, first_patch_idx: int = 1):
    return interpolate_pos_encoding(npatch_per_image, pos_embed, first_patch_idx)
_get_pos_embedding(npatch_per_image = 196, pos_embed = pos_embed, first_patch_idx=1).shape

torch.Size([1, 196, 768])

In [13]:
class SpatioTemporal_posEmbeddingHelper(VerboseNNModule):
    def __init__(self, num_patches: int, num_cls_tokens: int, embed_dim: int, learnable: bool):
        super().__init__()
        self.num_patches = num_patches
        self.num_cls_tokens = num_cls_tokens
        self.embed_dim = embed_dim
        self.learnable = learnable

        self.num_tokens = num_patches + num_cls_tokens

        if learnable:
            self.pos_embed = nn.Parameter(
                                torch.zeros(1, self.num_tokens, embed_dim)
                            )
            trunc_normal_(self.pos_embed, std=0.02)

        else: self.register_buffer(
            "pos_embed", get_sinusoid_encoding_table(n_position = self.num_tokens, d_hid = embed_dim)
            )
    
    def get_pos_embedding(self, all_vision_tokens):
        pos_embed = _get_pos_embedding(
            npatch_per_image = all_vision_tokens.size(1) - self.num_cls_tokens,
            pos_embed=self.pos_embed,
            first_patch_idx=self.num_cls_tokens,
        )
        return pos_embed


In [None]:
num_patches = 16
num_cls_tokens = 1
embed_dim = 768

# Fake patch embeddings for 2 images
batch_size = 2
patch_embeddings = torch.randn(batch_size, num_patches, embed_dim)  # [2, 16, 768]

# CLS token (typically learned)
cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))  # [1, 1, 768]
cls_tokens = cls_token.expand(batch_size, -1, -1)       # [2, 1, 768]

# Combine CLS and patches
all_vision_tokens = torch.cat([cls_tokens, patch_embeddings], dim=1)  # [2, 17, 768]
all_vision_tokens.shape

torch.Size([2, 17, 768])

In [15]:
pos_helper = SpatioTemporal_posEmbeddingHelper(
    num_patches=num_patches,
    num_cls_tokens=num_cls_tokens,
    embed_dim=embed_dim,
    learnable=False  # or True if you want learnable positions
)

vision_input = torch.randn(batch_size, 3, 224, 224)  # dummy input
pos_embed = pos_helper.get_pos_embedding(all_vision_tokens = all_vision_tokens)  # [1, 17, 768] (broadcastable)
pos_embed.shape

torch.Size([1, 17, 768])

In [16]:
class RGBTProcessor(VerboseNNModule):
    def __init__(self, rgbt_stem: PatchEmbedGeneric, img_size: Tuple = [3, 224, 224],
                 num_cls_token: int = 1, pos_embed_fn: SpatioTemporal_posEmbeddingHelper = None, 
                 use_type_embed: bool = False, init_param_style: str = "openclip"):
        super().__init__()

        self.embed_dim, self.patches_layout, self.num_patches = rgbt_stem.get_patch_layout(img_size)
        self.num_cls_token = num_cls_token
        self.use_type_embed = use_type_embed
        self.init_param_style = init_param_style
        self.use_pos_embed = pos_embed_fn is not None
        self.rgbt_stem = rgbt_stem

        if self.use_pos_embed:
            self.pos_embed_helper = pos_embed_fn(
                num_patches = self.num_patches,
                num_cls_tokens = self.num_cls_token,
                embed_dim = self.embed_dim,
                learnable = True
            )
        
        if num_cls_token > 0:
            self.cls_tokens = nn.Parameter(
                torch.zeros(1, self.num_cls_token, self.embed_dim)
            )
        if self.use_type_embed: # The model learns to adjust type_embed so that it provides differentiation for different modalities
            self.type_embed = nn.Parameter(
                torch.zeros(1, 1, self.embed_dim)
            )
        
        self.init_parameters(init_param_style)
    
    @torch.no_grad()
    def init_parameters(self, parameter_style):
        if parameter_style == "openclip":
            # OpenCLIP style initialization
            scale = self.embed_dim ** -0.5
        
            if self.use_type_embed:
                nn.init.normal_(self.pos_embed_helper.pos_embed)
                self.pos_embed_helper.pos_embed *= scale
            
            if self.num_cls_token > 0:
                nn.init.normal_(self.cls_tokens)
                self.cls_tokens *= scale
        
        elif parameter_style == "vit":
            self.cls_tokens.data.fill_(0)
        
        else:
            raise ValueError(f"Unknown init {parameter_style}")
        
        if self.use_type_embed:
            nn.init.normal_(self.type_embed)
        
    def tokenize_input_and_cls_pos(self, input, stem):
        tokens = stem(input)
        assert tokens.ndim == 3
        assert tokens.shape[-1] == self.embed_dim

        B = tokens.shape[0] # batch size
        
        if self.num_cls_token > 0:
            cls_tokens = self.cls_tokens.expand(B, -1, -1)   # Making sure Batches are matching or shape mismatch might occur
            tokens = torch.cat([cls_tokens, tokens], dim=1)

        if self.use_pos_embed:
            pos_embed = self.pos_embed_helper.get_pos_embedding(all_vision_tokens = tokens)
            tokens = tokens + pos_embed
        
        if self.use_type_embed:
            tokens = tokens + self.type_embed.expand(B, -1, -1)
        
        return tokens
    
    def forward(self, vision = None):
        vision_tokens = self.tokenize_input_and_cls_pos(input = vision, stem = self.rgbt_stem)
        return_dict = {
                        "trunk": {
                            "tokens": vision_tokens
                        },
                        "head": {}
                    }
        return return_dict


In [17]:
num_patches = 16
num_cls_tokens = 1
embed_dim = 768

In [18]:
proj_stem = [
    nn.Linear(224, 224),
    nn.ReLU()
]

patch_embed = PatchEmbedGeneric(proj_stem)

pos_helper = SpatioTemporal_posEmbeddingHelper(
    num_patches=num_patches,
    num_cls_tokens=num_cls_tokens,
    embed_dim=embed_dim,
    learnable=False  # or True if you want learnable positions
)

In [19]:
rgbt_processor = RGBTProcessor(rgbt_stem = patch_embed,
              pos_embed_fn = SpatioTemporal_posEmbeddingHelper, use_type_embed = True
              )

x = torch.randn(2, 3, 224, 224)

out = rgbt_processor(x)
out["trunk"]["tokens"]

tensor([[[ 2.3900,  0.2066,  0.0901],
         [ 1.0476, -1.0794,  1.0924],
         [ 1.6253, -0.0469,  1.5444],
         ...,
         [ 1.3145, -1.5348,  1.7872],
         [ 0.8960, -2.5253,  1.2743],
         [ 2.8765, -0.9186,  1.1537]],

        [[ 2.3900,  0.2066,  0.0901],
         [ 0.9397, -1.4308,  1.3125],
         [ 1.9686, -0.7751,  1.3938],
         ...,
         [ 2.7351, -1.5255,  0.1055],
         [ 1.4153,  0.4458, -0.9008],
         [ 3.8820,  0.4368,  0.0058]]], grad_fn=<AddBackward0>)

In [162]:
out["trunk"]["tokens"].shape

torch.Size([2, 50177, 3])

In [119]:
nn.init.normal(torch.rand(1,2,3))

  nn.init.normal(torch.rand(1,2,3))


tensor([[[-1.0355,  0.7191, -0.7407],
         [-1.3586, -0.0918, -1.3073]]])

In [107]:
RGBTProcessor(patch_embed)

3 (224, 224) 50176


RGBTProcessor()