In [39]:
import torch
from torch import nn
from typing import Callable, List, Optional, Tuple
import math
from timm.models.layers import trunc_normal_
import numpy as np

In [2]:
class VerboseNNModule(nn.Module):
    
    @staticmethod
    def get_readable_tensor_representation(name: str, tensor: torch.Tensor):
        st = (
            "(" + name + "): " + "tensor(" + str(tuple(tensor[1].shape)) + ", requires_grad=" + str(tensor[1].requires_grad) + ")\n"
            )
        return st
    
    def extra_repr(self) -> str:
        named_modules = set()
        for p in self.named_modules():
            named_modules.update(p[0])
        named_modules = list(named_modules)

        string_repr = ""
        for p in self.named_parameters():
            name = p[0].split(".")[0]
            if name in named_modules:
                string_repr += self.get_readable_tensor_representation(name, p)
        
        for p in self.named_buffers():
            name = p[0].split(".")[0]
            string_repr += self.get_readable_tensor_representation(name, p)
        
        return string_repr

class MyLayer(VerboseNNModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 2)
        self.register_buffer("running_mean", torch.zeros(5))

# Instantiate and print the model
model = MyLayer()
print(model)

MyLayer(
  (running_mean): tensor((5,), requires_grad=False)
  
  (linear): Linear(in_features=1, out_features=2, bias=True)
)


In [3]:
def build_causal_attention_mask(context_length):
    mask = torch.empty(context_length, context_length, requires_grad=False)
    mask.fill_(float("-inf"))
    mask.triu_(1)
    return mask

build_causal_attention_mask(5)

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [5]:
class TextPreprocessor(VerboseNNModule):
    def __init__(self, vocab_size: int, context_length: int, embed_dim: int, causual_mask: bool, 
                 supply_seq_len_to_head: bool = True, init_param_style: str = "openclip"):
        """
        `vocab_size`: Number of tokens in your vocabulary.                 the number of words in your text, so we can map nn.Embedding
	    `context_length`: Maximum number of tokens per input sequence.     usually: 77
	    `embed_dim`: Dimensionality of each token embedding.               usually: 768
        """

        super().__init__()
        
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.causual_mask = causual_mask
        self.embed_dim = embed_dim
        self.supply_seq_len_to_head = supply_seq_len_to_head
        
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(
            torch.empty(1, context_length, embed_dim)
        )
        if causual_mask:
            mask = build_causal_attention_mask(context_length)
            self.register_buffer("mask", mask) # register the mask as a buffer so it can be moved to the right device
        
        self.init_parameters(init_param_style)

    @torch.no_grad()
    def init_parameters(self, init_param_style = "openclip"):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

        # I did'nt use init_param_style as I was too lazy to implment [CLS]
    
    def forward(self, text):
        token_text = self.token_embedding(text)
        token_text = token_text + self.pos_embed
        
        return_dict = {
            "trunk": {
                "tokens": token_text
            },
            "head": {},
        }

        if self.supply_seq_len_to_head:
            text_lengths = text.argmax(dim = -1)
            #  hacky and non-standard way of getting the sequence length.
            return_dict["head"] = {
                "seq_len": text_lengths,
            }
        if self.causual_mask:
            return_dict["trunk"].update({"attn_mask": self.mask})
        
        return return_dict
    
vocab_size = 100
context_length = 77
embed_dim = 768

# Sample input: batch of 1, padded or truncated to 77 tokens
text = torch.randint(0, vocab_size, (2, context_length))  # shape [1, 77]

text_processor = TextPreprocessor(
    vocab_size=vocab_size,
    context_length=context_length,
    embed_dim=embed_dim,
    causual_mask=True,
    supply_seq_len_to_head=True
)

out = text_processor(text)
print(out["trunk"]["tokens"].shape)  # ➜ [2, 77, 768]

torch.Size([2, 77, 768])


In [6]:
out["trunk"]["attn_mask"].shape

torch.Size([77, 77])

In [7]:
out["head"]["seq_len"]

tensor([31, 28])

In [10]:
print(text_processor)

TextPreprocessor(
  (mask): tensor((77, 77), requires_grad=False)
  
  (token_embedding): Embedding(100, 768)
)


In [50]:
class PatchEmbedGeneric(nn.Module):
    def __init__(self, proj_stem, norm_layer: Optional[Callable] = None):
        super().__init__()

        if len(proj_stem) > 1:
            self.proj_stem = nn.Sequential(*proj_stem)
        else:
            # Special case to be able to load pre-trained models that were
            # trained with a standard stem
            self.proj = proj_stem[0]
        self.norm_layer = norm_layer
    
    def get_patch_layout(self, image_size):
        with torch.no_grad():
            dummy_img = torch.zeros([1,]) + image_size   # 1, C, (T), H, W
            dummy_out = self.proj(dummy_img)
        
        embed_dim = dummy_out[1]                        # `embed_dim`    = C        
        patch_layout = tuple(dummy_out.shape[2:])             # `patch_layout` = (T), H, W       
        num_patches = np.prod(patch_layout)             # `num_patches`  = (T) * H * W       
        return embed_dim, patch_layout, num_patches

    def forward(self, x: torch.Tensor):
        x = x.flatten(2)                                # B, C, (T), H, W -> B, C, (T)*H*W
        x = x.transpose(1, 2)                           # B, C, (T)*H*W   -> B, (T)*H*W, C
        if self.norm_layer is not None:
            x = self.norm_layer(x)
        return x

## Testing
proj_stem = [
    nn.Conv3d(3, 16, kernel_size=3, stride=2, padding=1),  # Assume input is (B, 3, 8, 64, 64)
    nn.ReLU()
]

patch_embed = PatchEmbedGeneric(proj_stem)
x = torch.randn(2, 3, 8, 64, 64)

out = patch_embed(x)
out.shape
 

torch.Size([2, 32768, 3])