In [1]:
import os
from functools import partial
from types import SimpleNamespace

import torch
from torch import nn

In [8]:
from helpers import LearnableLogitScaling, SelectElement, Normalize, SelectEOSandProject
from multimodal_processors import TextPreprocessor, Im2Video, PadIm2Video, PatchEmbedGeneric, SpatioTemporal_posEmbeddingHelper, RGBTProcessor
from transformer import MultiheadAttention, SimpleTransformer

In [3]:
ModalityType = SimpleNamespace(
    VISION="vision",
    TEXT="text",
)

In [12]:
class ImageBindModel(nn.Module):
    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        out_embed_dim=768,

        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,

        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
    ):
        super().__init__()

        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
        )

        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
        )

        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
        )

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),

        text_embed_dim=768,
    ):
        rgbt_stem = PatchEmbedGeneric(
            proj_stem =
            [
                PadIm2Video(ntimes=2, pad_type="repeat"), 
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                )
            
            ]
        )

        rgbt_preprocessor = RGBTProcessor(
            rgbt_stem = rgbt_stem,
            img_size = [3, video_frames, 224, 224],
            num_cls_token=1,
            pos_embed_fn=partial(SpatioTemporal_posEmbeddingHelper, learnable=True),
        )

        text_preprocessor = TextPreprocessor(
            context_length = 77,
            vocab_size = 49408,
            embed_dim=text_embed_dim,
            causual_mask=True,
        )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.TEXT: text_preprocessor,
        }

        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,

        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,

    ):
        def instantiate_trunk(embed_dim, num_blocks, drop_path, num_heads, add_bias_kv, pre_transformer_ln):
            simple_transformer = SimpleTransformer(
                embed_dim = embed_dim,
                num_blocks = num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate = drop_path,
                attn_target = 
                    MultiheadAttention(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                # I already maked sure that the shape is aligned when using MultiheadAttention so no need of rearrangement now
                pre_transformer_layer = nn.LayerNorm(embed_dim) if pre_transformer_ln else nn.Identity(),
                post_transformer_layer = nn.Identity()
            )
            return simple_transformer
        
        modality_trunks = {}
        modality_trunks[ModalityType.VISION] = instantiate_trunk(embed_dim = vision_embed_dim, num_blocks = vision_num_blocks, num_heads = vision_num_heads, drop_path = 0.0, pre_transformer_ln=True, add_bias_kv=False,)
        modality_trunks[ModalityType.TEXT] = instantiate_trunk(embed_dim = text_embed_dim, num_blocks = text_num_blocks, num_heads = text_num_heads, drop_path = 0.0, pre_transformer_ln=False, add_bias_kv=False,)
        
        return nn.ModuleDict(modality_trunks)
    
    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
    ):
        modality_heads = {}

        modality_heads[ModalityType.VISION] = nn.Sequential(
            nn.LayerNorm(normalized_shape = vision_embed_dim, eps=1e-6),
            SelectElement(index = 0),
            nn.Linear(vision_embed_dim, out_embed_dim, bias = False)
        )

        modality_heads[ModalityType.TEXT] = SelectEOSandProject(
            proj = nn.Sequential(
                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
            )
        )
        
        return nn.ModuleDict(modality_heads)

In [13]:
imagebindmodel = ImageBindModel()
imagebindmodel.modality_preprocessors

ModuleDict(
  (vision): RGBTProcessor(
    (rgbt_stem): PatchEmbedGeneric(
      (proj): Sequential(
        (0): PadIm2Video()
        (1): Conv3d(3, 1024, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
      )
    )
    (pos_embed_helper): SpatioTemporal_posEmbeddingHelper()
  )
  (text): TextPreprocessor(
    (mask): tensor((77, 77), requires_grad=False)
    
    (token_embedding): Embedding(49408, 768)
  )
)

In [14]:
imagebindmodel.modality_trunks

ModuleDict(
  (vision): SimpleTransformer(
    (pre_transformer_layer): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (post_transformer_layer): Identity()
    (blocks): Sequential(
      (0): BlockWithMasking(
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
        )
        (drop_path): Identity()
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (layer_scale_gamma1): Identity()
        (layer_scale_gamma2): Identity()
      )
      (1): BlockWithMasking(
        (attn): MultiheadAttention(
          (out_proj): NonDynamical

In [15]:
imagebindmodel.modality_heads

ModuleDict(
  (vision): Sequential(
    (0): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (1): SelectElement()
    (2): Linear(in_features=1024, out_features=768, bias=False)
  )
  (text): SelectEOSandProject(
    (proj): Sequential(
      (0): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (1): Linear(in_features=768, out_features=768, bias=False)
    )
  )
)