In [1]:
import os
from functools import partial
from types import SimpleNamespace

import torch
from torch import nn    

In [2]:
from helpers import LearnableLogitScaling, SelectElement, Normalize, SelectEOSandProject
from multimodal_processors import TextPreprocessor, PadIm2Video, PatchEmbedGeneric, SpatioTemporal_posEmbeddingHelper, RGBTProcessor
from transformer import MultiheadAttention, SimpleTransformer

In [3]:
ModalityType = SimpleNamespace(
    VISION="vision",
    TEXT="text",
)

In [4]:
class ImageBindModel(nn.Module):
    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        out_embed_dim=768,

        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,

        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
    ):
        super().__init__()

        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
        )

        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
        )

        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
        )

        self.modality_postprocessors = self._create_modality_postprocessors()

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size = (2, 14, 14),

        text_embed_dim=768,
    ):
        proj_stem = [
                PadIm2Video(ntimes=2, pad_type="repeat"), 
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim, #
                    stride=kernel_size,
                    bias=False,
                )
            
            ]
        rgbt_stem = PatchEmbedGeneric(proj_stem = proj_stem)

        rgbt_preprocessor = RGBTProcessor(
            rgbt_stem = rgbt_stem,
            img_size = [3, video_frames, 224, 224],
            num_cls_token=1,
            pos_embed_fn = SpatioTemporal_posEmbeddingHelper,
        )

        text_preprocessor = TextPreprocessor(
            context_length = 77,
            vocab_size = 49408,
            embed_dim=text_embed_dim,
            causual_mask=True,
        )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.TEXT: text_preprocessor,
        }

        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,

        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,

    ):
        def instantiate_trunk(embed_dim, num_blocks, drop_path, num_heads, add_bias_kv, pre_transformer_ln):
            multihead_cls = partial(
                MultiheadAttention,
                embed_dim=embed_dim,
                num_heads=num_heads,
                bias=True,
                add_bias_kv=add_bias_kv,
            )
            
            simple_transformer = SimpleTransformer(
                embed_dim = embed_dim,
                num_blocks = num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate = drop_path,
                attn_target = multihead_cls(),
                    
                # I already maked sure that the shape is aligned when using MultiheadAttention so no need of rearrangement now
                pre_transformer_layer = nn.LayerNorm(embed_dim) if pre_transformer_ln else nn.Identity(),
                post_transformer_layer = nn.Identity()
            )
            return simple_transformer
        
        modality_trunks = {}
        modality_trunks[ModalityType.VISION] = instantiate_trunk(embed_dim = vision_embed_dim, num_blocks = vision_num_blocks, num_heads = vision_num_heads, drop_path = 0.0, pre_transformer_ln=True, add_bias_kv=False,)
        modality_trunks[ModalityType.TEXT] = instantiate_trunk(embed_dim = text_embed_dim, num_blocks = text_num_blocks, num_heads = text_num_heads, drop_path = 0.0, pre_transformer_ln=False, add_bias_kv=False,)

        return nn.ModuleDict(modality_trunks)
    
    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
    ):
        modality_heads = {}

        modality_heads[ModalityType.VISION] = nn.Sequential(
            nn.LayerNorm(normalized_shape = vision_embed_dim, eps=1e-6),
            SelectElement(index = 0),
            nn.Linear(vision_embed_dim, out_embed_dim, bias = False)
        )

        modality_heads[ModalityType.TEXT] = SelectEOSandProject(
            proj = nn.Sequential(
                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
            )
        )
        
        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self):
        modality_postprocessors = {}

        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
        modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
            Normalize(dim=-1), LearnableLogitScaling(learnable=True)
        )

        return nn.ModuleDict(modality_postprocessors)
    
    def forward(self, inputs):
        outputs = {}
        for modality_key, modality_value in inputs.items():
            reduce_list = (modality_value.ndim >= 5)     # Because video's ndim is 5 (B, C, T, H, W)
            if reduce_list:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(B*S, *modality_value.shape[2:])
            
            if modality_value is not None:
                modality_value = self.modality_preprocessors[modality_key](**{modality_key: modality_value}) # Access the forward function of the classes

                trunk_inputs = modality_value["trunk"]
                head_inputs = modality_value["head"]
                
                modality_value = self.modality_trunks[modality_key](**trunk_inputs)                # Access the forward function of the classes
                modality_value = self.modality_heads[modality_key](modality_value, **head_inputs) # Access the forward function of the classes
                modality_value = self.modality_postprocessors[modality_key](modality_value)

            if reduce_list:
                modality_value = modality_value.reshape(B, S, -1)
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value
            
        return outputs
        

In [None]:
vocab_size = 100
context_length = 77


text = torch.randint(0, vocab_size, (2, context_length)) # text: [2, 77]
image = torch.randn(2, 3, 224, 224)

inputs = {
    ModalityType.VISION: image,
    ModalityType.TEXT: text
}

with torch.no_grad():
    imagebindmodel = ImageBindModel()
    outputs = imagebindmodel(inputs)
outputs  # don't worry about the zeros we initialized it like that, this will update during training

{'vision': tensor([[-0.0387, -0.0006,  0.0154,  ..., -0.0126,  0.0062,  0.0371],
         [ 0.0207, -0.0473, -0.0040,  ...,  0.0461,  0.0240,  0.0076]]),
 'text': tensor([[ 0.1791,  0.7256,  0.8233,  ...,  0.0891, -0.3678, -0.0287],
         [-0.6651,  0.7141,  0.4609,  ...,  0.7731, -0.2154,  0.6157]])}

In [8]:
outputs['vision'] @ outputs['text'].T

tensor([[-0.1303, -0.1767],
        [ 0.0482, -0.2382]])

In [7]:
for name, param in imagebindmodel.named_parameters():
    print(name, param.data)

modality_preprocessors.vision.cls_tokens tensor([[[ 0.0143,  0.0389, -0.0428,  ...,  0.0120,  0.0046, -0.0320]]])
modality_preprocessors.vision.rgbt_stem.proj.1.weight tensor([[[[[-8.4567e-03, -1.2238e-02,  2.7205e-02,  ...,  1.8445e-02,
             9.7923e-03, -9.6885e-03],
           [-1.5277e-02,  8.6289e-03, -2.6051e-02,  ..., -1.9239e-03,
            -1.0554e-02,  2.7062e-02],
           [ 1.6029e-02, -2.3984e-02, -8.9512e-03,  ...,  9.7922e-03,
             2.8035e-02, -2.1071e-02],
           ...,
           [ 5.0586e-03,  2.8686e-02,  1.7529e-03,  ...,  1.6769e-02,
            -5.6412e-03,  1.5022e-02],
           [ 2.0523e-02, -3.6020e-04,  1.1603e-02,  ..., -5.3394e-03,
             2.2108e-02, -2.8708e-02],
           [ 3.6148e-03, -7.8189e-03, -5.9277e-03,  ...,  5.3170e-03,
             1.5358e-02, -2.0196e-03]],

          [[-2.3806e-02,  1.9807e-02,  2.7607e-02,  ..., -3.0831e-03,
            -8.7890e-03, -2.0583e-02],
           [ 8.1969e-04, -2.9938e-04,  2.5507e-02, 

In [10]:
imagebindmodel = ImageBindModel()
imagebindmodel.modality_preprocessors

ModuleDict(
  (vision): RGBTProcessor(
    (rgbt_stem): PatchEmbedGeneric(
      (proj): Sequential(
        (0): PadIm2Video()
        (1): Conv3d(3, 1024, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
      )
    )
    (pos_embed_helper): SpatioTemporal_posEmbeddingHelper()
  )
  (text): TextPreprocessor(
    (mask): tensor((77, 77), requires_grad=False)
    
    (token_embedding): Embedding(49408, 768)
  )
)

In [11]:
vocab_size = 100
context_length = 77
embed_dim = 768

# Sample input: batch of 1, padded or truncated to 77 tokens
text = torch.randint(0, vocab_size, (2, context_length))
text

tensor([[57, 10, 71, 24, 78, 45, 96, 92, 33, 51, 71, 88, 27, 99, 99, 36, 99, 99,
         27, 34, 62, 89, 83, 11, 67, 37, 95, 27, 76, 87, 97,  5, 26, 29, 16, 69,
         75, 84, 28, 47, 94, 82, 50, 41, 62,  2,  4,  8, 71, 61, 61,  1, 86, 79,
          2, 23, 64, 57, 60, 19, 66, 84, 15, 17, 25, 95, 56, 56, 45,  9, 21, 22,
         67, 74, 70, 58, 45],
        [ 0, 84, 70, 84, 10, 55,  5, 16, 49, 17, 25, 83, 21, 62, 26, 48, 17, 28,
         82, 39, 69,  9, 21, 62, 14,  7,  8, 31, 94,  5, 52,  8, 70, 98, 30, 22,
         26, 67, 29, 26, 77, 91, 94, 55, 11, 30, 71, 35, 38, 99,  6, 42, 72, 83,
         32, 83, 45, 62, 47, 65, 55, 94, 35,  9, 90, 33, 31, 64, 98, 22, 16, 75,
         14, 88, 72, 33, 83]])

In [12]:
imagebindmodel.modality_preprocessors['text'](**{'text': text})

{'trunk': {'tokens': tensor([[[-0.0278, -0.0147,  0.0142,  ...,  0.0379,  0.0231, -0.0007],
           [ 0.0036,  0.0190, -0.0072,  ...,  0.0214, -0.0359, -0.0055],
           [-0.0086,  0.0439,  0.0641,  ...,  0.0344,  0.0206, -0.0269],
           ...,
           [-0.0097, -0.0564,  0.0203,  ..., -0.0143,  0.0201, -0.0112],
           [-0.0503, -0.0397, -0.0054,  ..., -0.0036,  0.0030, -0.0224],
           [-0.0284,  0.0165,  0.0259,  ..., -0.0183, -0.0066,  0.0612]],
  
          [[-0.0011, -0.0180,  0.0150,  ...,  0.0263,  0.0206,  0.0030],
           [-0.0132,  0.0372, -0.0092,  ..., -0.0099, -0.0692, -0.0601],
           [-0.0291, -0.0221,  0.0676,  ...,  0.0514,  0.0277, -0.0457],
           ...,
           [-0.0138, -0.0175,  0.0171,  ..., -0.0866, -0.0165,  0.0325],
           [ 0.0122, -0.0165, -0.0106,  ..., -0.0047, -0.0123, -0.0070],
           [ 0.0167,  0.0090,  0.0091,  ..., -0.0222, -0.0303,  0.0238]]],
         grad_fn=<AddBackward0>),
  'attn_mask': tensor([[0., -inf,

In [13]:
imagebindmodel.modality_trunks

ModuleDict(
  (vision): SimpleTransformer(
    (pre_transformer_layer): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (post_transformer_layer): Identity()
    (blocks): Sequential(
      (0): BlockWithMasking(
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
        )
        (drop_path): Identity()
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (layer_scale_gamma1): Identity()
        (layer_scale_gamma2): Identity()
      )
      (1): BlockWithMasking(
        (attn): MultiheadAttention(
          (out_proj): NonDynamical

In [14]:
imagebindmodel.modality_heads

ModuleDict(
  (vision): Sequential(
    (0): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (1): SelectElement()
    (2): Linear(in_features=1024, out_features=768, bias=False)
  )
  (text): SelectEOSandProject(
    (proj): Sequential(
      (0): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (1): Linear(in_features=768, out_features=768, bias=False)
    )
  )
)

In [15]:
imagebindmodel.modality_postprocessors

ModuleDict(
  (vision): Normalize()
  (text): Sequential(
    (0): Normalize()
    (1): LearnableLogitScaling(logit_scale_init=14.285714285714285,learnable=True, max_logit_scale=100)
  )
)