In [1]:
import os
from functools import partial
from types import SimpleNamespace

import torch
from torch import nn    

In [None]:
from helpers import LearnableLogitScaling, SelectElement, Normalize, SelectEOSandProject
from multimodal_processors import TextPreprocessor, PadIm2Video, PatchEmbedGeneric, SpatioTemporal_posEmbeddingHelper, RGBTProcessor
from transformer import MultiheadAttention, SimpleTransformer

In [None]:
# Names of the supported modalities. The string values are used as the keys of
# the per-modality ModuleDicts built by ImageBindModel and of its input dict.
ModalityType = SimpleNamespace(**{"VISION": "vision", "TEXT": "text"})

In [5]:
class ImageBindModel(nn.Module):
    """Vision + text joint-embedding model.

    Every modality is processed by four stages, each stored in an
    ``nn.ModuleDict`` keyed by the ``ModalityType`` values
    (``"vision"`` / ``"text"``):

    1. ``modality_preprocessors`` - turn raw input into trunk/head inputs.
    2. ``modality_trunks`` - a ``SimpleTransformer`` per modality.
    3. ``modality_heads`` - project into the shared ``out_embed_dim`` space.
    4. ``modality_postprocessors`` - normalisation (plus learnable logit
       scaling for text).

    Args:
        video_frames: Temporal size placed in the vision ``img_size``
            (``[3, video_frames, 224, 224]``).
        kernel_size: Kernel size and stride of the patchifying ``nn.Conv3d``.
        out_embed_dim: Output width of both projection heads.
        vision_embed_dim: Width of the vision trunk.
        vision_num_blocks: Number of transformer blocks in the vision trunk.
        vision_num_heads: Attention heads per vision block.
        text_embed_dim: Width of the text trunk.
        text_num_blocks: Number of transformer blocks in the text trunk.
        text_num_heads: Attention heads per text block.
    """

    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        out_embed_dim=768,

        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,

        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
    ):
        super().__init__()

        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
        )

        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
        )

        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
        )

        self.modality_postprocessors = self._create_modality_postprocessors()

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),

        text_embed_dim=768,
    ):
        """Build the vision and text preprocessors.

        Returns:
            ``nn.ModuleDict`` mapping ``ModalityType.VISION`` to an
            ``RGBTProcessor`` and ``ModalityType.TEXT`` to a
            ``TextPreprocessor``.
        """
        # Patch stem: PadIm2Video(ntimes=2, pad_type="repeat") followed by a
        # bias-free, non-overlapping (stride == kernel) Conv3d over 3 channels.
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                PadIm2Video(ntimes=2, pad_type="repeat"),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )

        rgbt_preprocessor = RGBTProcessor(
            rgbt_stem=rgbt_stem,
            img_size=[3, video_frames, 224, 224],
            num_cls_token=1,
            pos_embed_fn=partial(SpatioTemporal_posEmbeddingHelper, learnable=True),
        )

        text_preprocessor = TextPreprocessor(
            context_length=77,
            vocab_size=49408,
            embed_dim=text_embed_dim,
            causual_mask=True,  # keyword spelled as the project's TextPreprocessor is called here
        )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.TEXT: text_preprocessor,
        }

        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,

        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,

    ):
        """Build one ``SimpleTransformer`` trunk per modality.

        The vision trunk is preceded by a ``LayerNorm``; the text trunk uses
        ``nn.Identity`` instead. Both use ``drop_path = 0.0`` and no bias on
        the key/value projections.

        Returns:
            ``nn.ModuleDict`` keyed by ``ModalityType.VISION`` / ``ModalityType.TEXT``.
        """

        def instantiate_trunk(embed_dim, num_blocks, drop_path, num_heads, add_bias_kv, pre_transformer_ln):
            # NOTE(review): attn_target receives an already-constructed
            # MultiheadAttention instance. If SimpleTransformer reuses that
            # object for every block (rather than copying or re-creating it),
            # all blocks share the same attention weights - confirm against
            # transformer.SimpleTransformer.
            simple_transformer = SimpleTransformer(
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate=drop_path,
                attn_target=MultiheadAttention(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                # Shapes are already aligned for MultiheadAttention, so no
                # rearrangement layer is needed before or after the trunk.
                pre_transformer_layer=nn.LayerNorm(embed_dim) if pre_transformer_ln else nn.Identity(),
                post_transformer_layer=nn.Identity(),
            )
            return simple_transformer

        modality_trunks = {}
        modality_trunks[ModalityType.VISION] = instantiate_trunk(
            embed_dim=vision_embed_dim,
            num_blocks=vision_num_blocks,
            num_heads=vision_num_heads,
            drop_path=0.0,
            pre_transformer_ln=True,
            add_bias_kv=False,
        )
        modality_trunks[ModalityType.TEXT] = instantiate_trunk(
            embed_dim=text_embed_dim,
            num_blocks=text_num_blocks,
            num_heads=text_num_heads,
            drop_path=0.0,
            pre_transformer_ln=False,
            add_bias_kv=False,
        )

        return nn.ModuleDict(modality_trunks)

    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
    ):
        """Build the projection heads that map each trunk into ``out_embed_dim``.

        Vision: LayerNorm -> ``SelectElement(index=0)`` -> bias-free Linear.
        Text: ``SelectEOSandProject`` wrapping LayerNorm -> bias-free Linear.

        Returns:
            ``nn.ModuleDict`` keyed by ``ModalityType.VISION`` / ``ModalityType.TEXT``.
        """
        modality_heads = {}

        modality_heads[ModalityType.VISION] = nn.Sequential(
            nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.TEXT] = SelectEOSandProject(
            proj=nn.Sequential(
                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
            )
        )

        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self):
        """Build the per-modality postprocessors.

        Vision: ``Normalize(dim=-1)``.
        Text: ``Normalize(dim=-1)`` followed by ``LearnableLogitScaling(learnable=True)``.

        Returns:
            ``nn.ModuleDict`` keyed by ``ModalityType.VISION`` / ``ModalityType.TEXT``.
        """
        modality_postprocessors = {}

        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
        modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
            Normalize(dim=-1), LearnableLogitScaling(learnable=True)
        )

        return nn.ModuleDict(modality_postprocessors)

    def forward(self, inputs):
        """Embed every modality present in ``inputs``.

        Args:
            inputs: Mapping from modality key (``"vision"`` / ``"text"``) to a
                tensor, or to ``None`` for a modality that should be skipped.
                A tensor with 5 or more dimensions is treated as
                ``(B, S, ...)``: the first two axes are folded together for
                processing and the ``S`` resulting embeddings per item are
                averaged.

        Returns:
            Dict with the same keys as ``inputs``. Each tensor value is the
            output of preprocessor -> trunk -> head -> postprocessor for that
            modality; ``None`` inputs are passed through as ``None``.
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            # Check for None first: reading `.ndim` on None would raise.
            if modality_value is None:
                outputs[modality_key] = None
                continue

            # >= 5 dims: fold the leading (B, S) axes into a single batch axis.
            reduce_list = modality_value.ndim >= 5
            if reduce_list:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(B * S, *modality_value.shape[2:])

            # The preprocessor takes its input as a keyword named after the
            # modality and returns separate kwargs for the trunk and the head.
            modality_value = self.modality_preprocessors[modality_key](**{modality_key: modality_value})
            trunk_inputs = modality_value["trunk"]
            head_inputs = modality_value["head"]

            modality_value = self.modality_trunks[modality_key](**trunk_inputs)
            modality_value = self.modality_heads[modality_key](modality_value, **head_inputs)
            # Apply the postprocessor built in __init__ (normalisation / logit
            # scaling); without this step the returned embeddings are raw
            # head outputs.
            modality_value = self.modality_postprocessors[modality_key](modality_value)

            if reduce_list:
                # Unfold (B * S, D) -> (B, S, D) and average over S.
                modality_value = modality_value.reshape(B, S, -1)
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value

        return outputs
        

In [5]:
# Build the model with its default hyperparameters, then display the
# per-modality preprocessors it registered.
imagebindmodel = ImageBindModel()
imagebindmodel.modality_preprocessors

ModuleDict(
  (vision): RGBTProcessor(
    (rgbt_stem): PatchEmbedGeneric(
      (proj): Sequential(
        (0): PadIm2Video()
        (1): Conv3d(3, 1024, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
      )
    )
    (pos_embed_helper): SpatioTemporal_posEmbeddingHelper()
  )
  (text): TextPreprocessor(
    (mask): tensor((77, 77), requires_grad=False)
    
    (token_embedding): Embedding(49408, 768)
  )
)

In [6]:
vocab_size = 100       # upper bound (exclusive) for the random token ids below
context_length = 77    # same value the model's TextPreprocessor is built with
embed_dim = 768        # not used in this cell

# Sample input: batch of 2 sequences, each with `context_length` (77) random token ids
text = torch.randint(0, vocab_size, (2, context_length))
text

tensor([[82, 17, 70, 60, 14, 99, 52, 13,  7, 91, 43, 80, 68, 16, 70, 22, 13, 14,
         10, 48, 95, 14, 94, 35, 47, 61, 55, 43, 14, 74,  9, 29, 91, 11, 50, 20,
         91, 90, 86, 30, 17, 86,  5, 93, 88, 12, 26, 73, 14, 59,  1, 49, 28, 90,
         70, 79, 38, 81, 84, 50, 28, 68, 42, 94, 75,  2, 70, 10, 18, 80, 35,  5,
         17, 43, 86, 26, 96],
        [16,  5, 90, 40, 60, 46, 21, 51, 88, 52, 68, 32, 58, 95, 28, 71,  2, 40,
         87, 25, 80,  3, 67, 30, 58, 64, 84, 55, 26, 15, 56,  9, 38, 57, 87, 41,
         30, 72, 86, 84, 17, 17, 51, 71, 18, 86, 58,  1, 95, 61, 41,  2,  3, 58,
         43, 66, 26, 74, 57, 19,  0, 99, 92,  7, 22, 42,  8, 65, 62, 81, 21, 74,
         67, 13, 19, 80, 97]])

In [7]:
# Run only the text preprocessor on the sample token ids.
imagebindmodel.modality_preprocessors["text"](text=text)

{'trunk': {'tokens': tensor([[[ 0.0079, -0.0165, -0.0300,  ..., -0.0457, -0.0228,  0.0185],
           [ 0.0387, -0.0072, -0.0253,  ...,  0.0064,  0.0061,  0.0076],
           [-0.0428, -0.0018, -0.0102,  ..., -0.0243, -0.0410, -0.0098],
           ...,
           [-0.0368, -0.0145,  0.0055,  ..., -0.0097, -0.0285,  0.0355],
           [-0.0196,  0.0123,  0.0011,  ..., -0.0122,  0.0154,  0.0118],
           [ 0.0130, -0.0391, -0.0155,  ...,  0.0330, -0.0519, -0.0448]],
  
          [[ 0.0168, -0.0265, -0.0327,  ..., -0.0205,  0.0125,  0.0087],
           [ 0.0185,  0.0153,  0.0095,  ...,  0.0276, -0.0311, -0.0119],
           [ 0.0086,  0.0456,  0.0010,  ...,  0.0005, -0.0140, -0.0327],
           ...,
           [-0.0003, -0.0019, -0.0004,  ...,  0.0419, -0.0122,  0.0537],
           [-0.0098, -0.0520,  0.0379,  ...,  0.0140,  0.0317,  0.0182],
           [ 0.0003, -0.0311, -0.0242,  ...,  0.0317,  0.0162,  0.0165]]],
         grad_fn=<AddBackward0>),
  'attn_mask': tensor([[0., -inf,

In [8]:
# Display the transformer trunks registered for each modality.
imagebindmodel.modality_trunks

ModuleDict(
  (vision): SimpleTransformer(
    (pre_transformer_layer): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (post_transformer_layer): Identity()
    (blocks): Sequential(
      (0): BlockWithMasking(
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
        )
        (drop_path): Identity()
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (layer_scale_gamma1): Identity()
        (layer_scale_gamma2): Identity()
      )
      (1): BlockWithMasking(
        (attn): MultiheadAttention(
          (out_proj): NonDynamical

In [9]:
# Display the projection heads registered for each modality.
imagebindmodel.modality_heads

ModuleDict(
  (vision): Sequential(
    (0): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (1): SelectElement()
    (2): Linear(in_features=1024, out_features=768, bias=False)
  )
  (text): SelectEOSandProject(
    (proj): Sequential(
      (0): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (1): Linear(in_features=768, out_features=768, bias=False)
    )
  )
)

In [10]:
# Display the postprocessors registered for each modality.
imagebindmodel.modality_postprocessors

ModuleDict(
  (vision): Normalize()
  (text): Sequential(
    (0): Normalize()
    (1): LearnableLogitScaling(logit_scale_init=14.285714285714285,learnable=True, max_logit_scale=100)
  )
)

In [25]:
# Stand-alone copy of the vision patch stem (same construction as in
# ImageBindModel._create_modality_preprocessors), used to inspect the patch
# layout it reports for an img_size of [3, 3, 224, 224].
kernel_size = (2, 14, 14)
vision_embed_dim = 1024

rgbt_stem = PatchEmbedGeneric(
    proj_stem=[
        PadIm2Video(pad_type="repeat", ntimes=2),
        # Bias-free Conv3d whose stride equals its kernel size.
        nn.Conv3d(3, vision_embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False),
    ]
)

rgbt_stem.get_patch_layout([3, 3, 224, 224])

(1024, (1, 16, 16), 256)

In [None]:
# End-to-end smoke test: push one small text batch and one small image batch
# through a freshly constructed model.
vocab_size = 100
# Must equal the context_length the model's TextPreprocessor is built with (77);
# a shorter sequence would not line up with its (77, 77) attention mask.
context_length = 77
embed_dim = 8  # not consumed by anything in this cell


# Token ids: (batch=2, context_length=77)
text = torch.randint(0, vocab_size, (2, context_length))
# Image batch: (batch=2, channels=3, height=224, width=224).
# The patch stem's Conv3d has in_channels=3 and the vision img_size is
# [3, video_frames, 224, 224]. The previous shape (2, 1024, 256, 224, 224)
# held ~2.6e10 floats (~105 GB) and had the wrong channel count.
image = torch.randn(2, 3, 224, 224)

inputs = {
    ModalityType.VISION: image,
    ModalityType.TEXT: text
}

imagebindmodel = ImageBindModel()
outputs = imagebindmodel(inputs)

In [None]:
# Display only the vision transformer trunk.
imagebindmodel.modality_trunks['vision']

SimpleTransformer(
  (pre_transformer_layer): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (post_transformer_layer): Identity()
  (blocks): Sequential(
    (0): BlockWithMasking(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
      )
      (drop_path): Identity()
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (layer_scale_gamma1): Identity()
      (layer_scale_gamma2): Identity()
    )
    (1): BlockWithMasking(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)


In [None]:
# Shape of the text embedding returned by `imagebindmodel(inputs)`.
outputs['text'].shape

torch.Size([2, 768])