In [1]:
import os
from functools import partial
from types import SimpleNamespace

import torch
from torch import nn

In [3]:
from helpers import LearnableLogitScaling, SelectElement, Normalize
from multimodal_processors import TextPreprocessor, Im2Video, PadIm2Video, PatchEmbedGeneric, SpatioTemporal_posEmbeddingHelper, RGBTProcessor
from transformer import MultiheadAttention, SimpleTransformer

In [17]:
# String keys identifying each modality; reused as the keys of the
# per-modality module dictionaries built below.
ModalityType = SimpleNamespace(VISION="vision", TEXT="text")

In [None]:
class ImageBindModel(nn.Module):
    """Work-in-progress multimodal model holding vision and text preprocessors.

    Construction currently builds only ``self.modality_preprocessors``; the
    transformer trunks are not implemented yet (see
    ``_create_modality_trunks``).
    """

    def __init__(
        self,
        video_frames: int = 2,
        kernel_size: tuple = (2, 14, 14),
        out_embed_dim: int = 768,

        vision_embed_dim: int = 1024,
        vision_num_blocks: int = 24,
        vision_num_heads: int = 16,

        text_embed_dim: int = 768,
        text_num_blocks: int = 12,
        text_num_heads: int = 12,
    ):
        """Build the per-modality preprocessors.

        Args:
            video_frames: Temporal size placed in the vision ``img_size``.
            kernel_size: Kernel (and stride) of the vision ``nn.Conv3d`` stem.
            out_embed_dim: Accepted but not used yet by this constructor.
            vision_embed_dim: Output channels of the vision ``nn.Conv3d`` stem.
            vision_num_blocks: Accepted but not used yet by this constructor.
            vision_num_heads: Accepted but not used yet by this constructor.
            text_embed_dim: ``embed_dim`` passed to ``TextPreprocessor``.
            text_num_blocks: Accepted but not used yet by this constructor.
            text_num_heads: Accepted but not used yet by this constructor.
        """
        super().__init__()

        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
        )

    def _create_modality_preprocessors(
        self,
        video_frames: int = 2,
        vision_embed_dim: int = 1024,
        kernel_size: tuple = (2, 14, 14),

        text_embed_dim: int = 768,
    ) -> nn.ModuleDict:
        """Create the vision and text preprocessors.

        Args:
            video_frames: Temporal size placed in the vision ``img_size``.
            vision_embed_dim: Output channels of the vision ``nn.Conv3d`` stem.
            kernel_size: Kernel size of the ``nn.Conv3d`` stem; also used as
                its stride.
            text_embed_dim: ``embed_dim`` passed to ``TextPreprocessor``.

        Returns:
            An ``nn.ModuleDict`` keyed by ``ModalityType.VISION`` and
            ``ModalityType.TEXT``.
        """
        # Vision stem: PadIm2Video followed by a bias-free 3D convolution
        # whose stride equals its kernel size.
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                PadIm2Video(ntimes=2, pad_type="repeat"),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )

        rgbt_preprocessor = RGBTProcessor(
            rgbt_stem=rgbt_stem,
            img_size=[3, video_frames, 224, 224],
            num_cls_token=1,
            pos_embed_fn=partial(SpatioTemporal_posEmbeddingHelper, learnable=True),
        )

        text_preprocessor = TextPreprocessor(
            context_length=77,
            vocab_size=49408,
            embed_dim=text_embed_dim,
            causual_mask=True,
        )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.TEXT: text_preprocessor,
        }

        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,

        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,

    ):
        """Placeholder for building the per-modality transformer trunks.

        The original definition had a signature but no body, which made the
        whole class a syntax error. Until the trunks are written this raises
        explicitly so the class can be defined and the missing piece is
        obvious to any caller.

        Args:
            vision_embed_dim, vision_num_blocks, vision_num_heads: presumably
                the width, depth and head count of the vision trunk
                (TODO confirm once implemented).
            text_embed_dim, text_num_blocks, text_num_heads: presumably the
                same for the text trunk (TODO confirm once implemented).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "ImageBindModel._create_modality_trunks is not implemented yet"
        )

In [22]:
# Instantiate the model with its default hyperparameters and show the
# registered preprocessors (the bare last expression renders their repr).
imagebindmodel = ImageBindModel()
imagebindmodel.modality_preprocessors

ModuleDict(
  (vision): RGBTProcessor(
    (rgbt_stem): PatchEmbedGeneric(
      (proj): Sequential(
        (0): PadIm2Video()
        (1): Conv3d(3, 1024, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
      )
    )
    (pos_embed_helper): SpatioTemporal_posEmbeddingHelper()
  )
  (text): TextPreprocessor(
    (mask): tensor((77, 77), requires_grad=False)
    
    (token_embedding): Embedding(49408, 768)
  )
)