## RegionTransformer

In [50]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer

class RegionTransformer(nn.Module):
    """Fuse RGB/depth region tokens with each other and with image tokens.

    Pipeline (see ``forward``):
      1. Build the sequence ``[CLS] + RGB region tokens + depth region tokens``.
      2. Add learned positional embeddings and run a self-attention encoder.
      3. Cross-attend the encoded region tokens (queries) to the image tokens
         (memory) with a ``TransformerDecoder``.
      4. Split the output back into the CLS token, RGB tokens and depth tokens.

    Args:
        d_model: Feature dimension of every token (regions and image).
        nhead: Number of attention heads per layer.
        num_self_layers: Number of ``TransformerEncoderLayer`` blocks.
        num_cross_layers: Number of ``TransformerDecoderLayer`` blocks.
        max_regions: Maximum regions per modality; the region positional table
            has ``2 * max_regions + 1`` rows (CLS + RGB + depth).
        max_tokens_img: Maximum number of image tokens (rows of the image
            positional table).
    """

    def __init__(self, d_model=1152, nhead=8, num_self_layers=6, num_cross_layers=4, max_regions=15, max_tokens_img=800):
        super().__init__()
        self.d_model = d_model
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # Learned positional embeddings; forward() only uses as many rows as it needs.
        self.pos_embed_region = nn.Parameter(torch.randn(1, 2 * max_regions + 1, d_model))
        self.pos_embed_image = nn.Parameter(torch.randn(1, max_tokens_img, d_model))

        # Self-attention over region tokens, then cross-attention to image tokens.
        self.region_encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_self_layers
        )
        self.cross_decoder = TransformerDecoder(
            TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_cross_layers
        )

    def forward(self, Freg_rgb, Freg_depth, Fimg_rgb):
        """Self-attend over region tokens, then cross-attend them to image tokens.

        Args:
            Freg_rgb: (B, Nr, d_model) RGB region tokens.
            Freg_depth: (B, Nd, d_model) depth region tokens (normally Nd == Nr).
            Fimg_rgb: (B, Nt, d_model) image tokens.

        Returns:
            CLS: (B, d_model) output at the CLS position.
            Freg_rgb_trans: (B, Nr, d_model) refined RGB region tokens.
            Freg_depth_trans: (B, Nd, d_model) refined depth region tokens.

        Raises:
            ValueError: if ``1 + Nr + Nd`` exceeds the region positional table or
                ``Nt`` exceeds the image positional table.
        """
        B, Nr, F = Freg_rgb.shape
        Nd = Freg_depth.shape[1]
        _, Nt, _ = Fimg_rgb.shape

        num_region_tokens = 1 + Nr + Nd
        max_region_tokens = self.pos_embed_region.shape[1]
        if num_region_tokens > max_region_tokens:
            raise ValueError(
                f"Got {num_region_tokens} region tokens (1 CLS + {Nr} RGB + {Nd} depth) but "
                f"pos_embed_region only has {max_region_tokens} rows; increase max_regions."
            )
        max_img_tokens = self.pos_embed_image.shape[1]
        if Nt > max_img_tokens:
            raise ValueError(
                f"Got {Nt} image tokens but pos_embed_image only has {max_img_tokens} rows; "
                f"increase max_tokens_img."
            )

        # [CLS] + RGB regions + depth regions -> (B, 1 + Nr + Nd, F)
        cls = self.cls_token.expand(B, -1, -1)
        Freg = torch.cat([cls, Freg_rgb, Freg_depth], dim=1)
        # Out-of-place add; only the first num_region_tokens positions are used.
        Freg = Freg + self.pos_embed_region[:, :num_region_tokens, :]

        # Region encoder (self-attention among CLS/RGB/depth tokens)
        Freg_encoded = self.region_encoder(Freg)                    # (B, 1+Nr+Nd, F)

        # Add position embedding for image features
        Fimg_rgb_pos = Fimg_rgb + self.pos_embed_image[:, :Nt, :]   # (B, Nt, F)

        # Cross attention: region tokens are queries, image tokens are memory.
        X = self.cross_decoder(Freg_encoded, Fimg_rgb_pos)          # (B, 1+Nr+Nd, F)

        # Split at the known boundary (not chunk(2)) so unequal Nr/Nd stay correct.
        CLS = X[:, 0, :]                                            # (B, F)
        Freg_rgb_trans = X[:, 1:1 + Nr, :]                          # (B, Nr, F)
        Freg_depth_trans = X[:, 1 + Nr:, :]                         # (B, Nd, F)

        return CLS, Freg_rgb_trans, Freg_depth_trans

# Smoke test on one sample: 5 RGB regions, 5 depth regions, 792 image tokens.
torch.manual_seed(0)  # reproducible random inputs and weights
Freg_rgb = torch.randn(1, 5, 1152)
Freg_depth = torch.randn(1, 5, 1152)
Fimg_rgb = torch.randn(1, 792, 1152)

model = RegionTransformer().eval()  # eval() disables dropout so the pass is deterministic
with torch.no_grad():  # inference-only demo
    CLS, Freg_rgb_trans, Freg_depth_trans = model(Freg_rgb, Freg_depth, Fimg_rgb)

# CLS: global feature of the regions, for classification/regression heads.
# Freg_rgb_trans, Freg_depth_trans: to be passed through RGBProjector and DepthProjector.
assert CLS.shape == (1, 1152)
assert Freg_rgb_trans.shape == Freg_depth_trans.shape == (1, 5, 1152)
print(CLS.shape)
print(Freg_rgb_trans.shape)
print(Freg_depth_trans.shape)

torch.Size([1, 1152])
torch.Size([1, 5, 1152])
torch.Size([1, 5, 1152])


## Original RegionExtractor

In [1]:
import os
import os.path as osp
import re
import sys

import einops
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class LayerNorm2d(nn.Module):
    """Layer normalisation across the channel axis (dim 1) of a 4-D tensor.

    Every spatial position is normalised independently over its channels
    (biased variance, ``eps`` added inside the square root). The result is then
    scaled and shifted by learnable per-channel ``weight`` and ``bias``.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Registered in this order so state_dict keys stay "weight", "bias".
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        centered = x - x.mean(dim=1, keepdim=True)
        variance = centered.pow(2).mean(dim=1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # Broadcast per-channel affine parameters over the trailing H, W axes.
        scale = self.weight.view(-1, 1, 1)
        shift = self.bias.view(-1, 1, 1)
        return scale * normalized + shift


class MaskPooling(nn.Module):
    """Pool token features inside (soft) region masks.

    For each sample, every mask is bilinearly resized to the token grid of ``x``
    and used as a weight map. The pooled feature of a mask is the weighted sum of
    that sample's tokens divided by the mask's total weight.

    Args:
        mask_threshold: Threshold applied to the resized masks to build the
            binary attention masks returned when ``return_mask=True``.
    """

    def __init__(self, mask_threshold=0.5):
        super().__init__()
        self.mask_threshold = mask_threshold

    def forward(self, x, mask_list, return_list=False, return_mask=False):
        """
        Args:
            x: [B, (HW), C] flattened token features.
            mask_list: List( tensor[M, IH, IW] ) of length B, or None. Individual
                entries may be None for samples without masks.
            return_list: If True, return a per-sample list (None entries kept).
                Otherwise concatenate the pooled features of every sample that
                has masks into one [sum(M), C] tensor ([0, C] if none do).
            return_mask: Only honoured with ``return_list=True``. When set, the
                method also returns the per-sample thresholded [M, HW] masks.

        Raises:
            ValueError: if a mask cannot be resized to exactly HW tokens (mask and
                feature map must have the same aspect ratio).
        """
        batch_size, x_len = x.size(0), x.size(1)
        if mask_list is None:
            mask_list = [None] * batch_size

        output = []
        attn_mask_list = []
        for i in range(batch_size):
            mask = mask_list[i]
            if mask is None:
                output.append(None)
                attn_mask_list.append(None)
                continue

            # Resize mask from image shape to feature-map shape. The target size is
            # computed and rounded explicitly: interpolate(scale_factor=...) floors
            # in_size * scale, so a slightly-low float scale could drop a row/column.
            in_h, in_w = mask.size(-2), mask.size(-1)
            scale = (x_len / (in_h * in_w)) ** 0.5
            out_h, out_w = round(in_h * scale), round(in_w * scale)
            if out_h * out_w != x_len:
                raise ValueError(
                    f"Cannot resize a {in_h}x{in_w} mask to {x_len} feature tokens "
                    f"(got {out_h}x{out_w}); mask and feature map must share the same aspect ratio."
                )

            mask = mask.detach().float()[None, ...]  # 1, M, IH, IW (masks as channels)
            mask = nn.functional.interpolate(mask, size=(out_h, out_w), mode="bilinear")
            mask = mask.to(x.dtype)[0]  # M, H, W

            denorm = mask.sum(dim=(-1, -2)).unsqueeze(-1) + 1e-8  # M, 1 (eps avoids /0)
            mask = mask.flatten(start_dim=1)  # M, H, W -> M, HW

            attn_mask_list.append((mask > self.mask_threshold).to(mask.dtype))  # M, HW

            # Weighted mean of the sample's tokens under each normalised mask -> M, C
            output.append(torch.einsum("lc,ml->mc", x[i], mask / denorm))

        if return_list:
            if return_mask:
                return output, attn_mask_list
            return output

        # Samples without masks contribute no rows to the concatenated output.
        pooled = [o for o in output if o is not None]
        if not pooled:
            return x.new_zeros((0, x.size(-1)))
        return torch.cat(pooled)


def get_feature_refinement_module(vision_hidden_size, feature_refinement_type="deconv2x"):
    """Build the upsampling module that refines vision-tower feature maps.

    ``"deconvNx"`` builds N ``ConvTranspose2d(kernel_size=2, stride=2)`` layers, each
    doubling H and W (total upsampling 2**N, e.g. 27x27 -> 108x108 for N=2).
    Every layer except the last is followed by ``LayerNorm2d`` + GELU; the last
    is followed by GELU only. Channel count stays ``vision_hidden_size``.

    Args:
        vision_hidden_size: Number of input/output channels.
        feature_refinement_type: Module spec, currently only ``"deconvNx"`` with N >= 1.

    Returns:
        nn.Sequential mapping (N, C, H, W) -> (N, C, H * 2**N, W * 2**N).

    Raises:
        ValueError: for an unknown type or N < 1 ("deconv0x" previously still built one layer).
    """
    deconv_match = re.match(r"^deconv(\d+)x$", feature_refinement_type)
    if deconv_match is None:
        raise ValueError(f"Unknown feature refinement type: {feature_refinement_type}")

    deconv_depth = int(deconv_match.group(1))
    if deconv_depth < 1:
        raise ValueError(
            f"Feature refinement type {feature_refinement_type!r} needs at least one deconv layer"
        )

    modules = []
    for _ in range(deconv_depth - 1):
        modules += [
            nn.ConvTranspose2d(vision_hidden_size, vision_hidden_size, kernel_size=2, stride=2),
            LayerNorm2d(vision_hidden_size),
            nn.GELU(),
        ]
    # Final layer: no normalisation, just the activation.
    modules += [
        nn.ConvTranspose2d(vision_hidden_size, vision_hidden_size, kernel_size=2, stride=2),
        nn.GELU(),
    ]
    return nn.Sequential(*modules)


class RegionExtractorConfig(PretrainedConfig):
    """Configuration for ``RegionExtractor``.

    Args:
        region_extractor_type: Extractor variant; ``RegionExtractor`` in this
            notebook recognises "regiongpt", "duplicate" and "duplicate_deconv".
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "region_extractor"

    def __init__(self, region_extractor_type: str = None, **kwargs):
        # Forward kwargs so standard PretrainedConfig options (and fields restored
        # when a config is rebuilt from a dict) are not silently discarded.
        super().__init__(**kwargs)
        self.region_extractor_type = region_extractor_type


class RegionExtractor(PreTrainedModel):
    """Extract per-region RGB/depth embeddings by mask pooling and projection.

    Only the "regiongpt" variant is functional. It owns a mask pooler, a deconv
    feature-refinement module, a 27x27 adaptive pool, and two linear projectors
    (``mm_hidden_size -> hidden_size``) for RGB and depth. "duplicate" and
    "duplicate_deconv" construct partial modules, but feature extraction raises
    ``NotImplementedError`` for them.
    NOTE(review): other types construct an empty module without complaint; they
    only fail (ValueError) once ``extract_region_features`` is called.
    """

    config_class = RegionExtractorConfig

    def __init__(self, region_extractor_cfg: RegionExtractorConfig, config: PretrainedConfig):
        super().__init__(region_extractor_cfg)
        region_extractor_type = region_extractor_cfg.region_extractor_type

        if region_extractor_type == "regiongpt":
            self.mask_pooling = MaskPooling()
            self.feature_refinement_module = get_feature_refinement_module(config.mm_hidden_size)
            # TODO: hardcoded pooling size here, should be inside cfg
            self.ada_pooling = nn.AdaptiveAvgPool2d(27)
            self.rgb_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
            self.depth_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
        elif region_extractor_type == "duplicate":
            self.mask_pooling = MaskPooling()
            self.rgb_projector = None
            self.depth_projector = None
        elif region_extractor_type == "duplicate_deconv":
            self.feature_refinement_module = get_feature_refinement_module(config.mm_hidden_size)
            self.ada_pooling = nn.AdaptiveAvgPool2d(27)
            self.mask_pooling = MaskPooling()
            self.rgb_projector = None
            self.depth_projector = None

    def feature_refinement(self, tower_features):
        """Upsample square token grids and also return an adaptively pooled copy.

        Args:
            tower_features: 'N (H W) C' tokens; H is taken as int(sqrt(HW)), so the
                grid is assumed square.

        Returns:
            (hres, lres): both flattened to 'N (H W) C'. hres is the output of the
            refinement module; lres is hres adaptively pooled to 27x27.
        """
        HW = tower_features.shape[1]
        tower_features = einops.rearrange(tower_features, "N (H W) C -> N C H W", H=int(HW**0.5))
        hres_tower_features = self.feature_refinement_module(tower_features)
        # local feature branch
        hres_tower_features_flatten = einops.rearrange(hres_tower_features, "N C H W -> N (H W) C")

        # global feature branch
        ada_image_feature = self.ada_pooling(hres_tower_features)
        lres_tower_features_flatten = einops.rearrange(ada_image_feature, "N C H W -> N (H W) C")
        return hres_tower_features_flatten, lres_tower_features_flatten

    def extract_region_features(self, hres_tower_features, masks, connector):
        """Mask-pool features per sample, then apply ``connector`` to each result.

        Args:
            hres_tower_features: flattened 'N (H W) C' features.
            masks: list (length N) of [M, IH, IW] mask tensors, or None entries.
            connector: callable applied to each [M, C] pooled tensor (a projector).

        Returns:
            List with one entry per sample: ``connector(pooled)``, or None where the
            sample had no masks.

        Raises:
            NotImplementedError: for "duplicate" / "duplicate_deconv".
            ValueError: for any other unrecognised ``region_extractor_type``
                (previously surfaced as an UnboundLocalError).
        """
        extractor_type = self.config.region_extractor_type
        if extractor_type in ["duplicate", "duplicate_deconv"]:
            raise NotImplementedError(f"{self.config.region_extractor_type} not implemented")
        if extractor_type != "regiongpt":
            raise ValueError(f"Unknown region_extractor_type: {extractor_type}")

        mask_embeds = self.mask_pooling(hres_tower_features, masks, return_list=True)
        return [None if mask_embed is None else connector(mask_embed) for mask_embed in mask_embeds]

    def forward(self,
                image_features,  # flattened RGB features to pool (the demo passes the hres output)
                depth_features,
                masks,
                *args, **kwargs):
        """Return projected region embeddings for RGB and, if given, depth features.

        Extra positional/keyword arguments are accepted and ignored.

        Returns:
            (mask_embeds, depth_embeds): per-sample lists from
            ``extract_region_features``; ``depth_embeds`` is None when
            ``depth_features`` is None.
        """
        mask_embeds = self.extract_region_features(image_features, masks, self.rgb_projector)
        if depth_features is not None:
            depth_embeds = self.extract_region_features(depth_features, masks, self.depth_projector)
        else:
            depth_embeds = None
        return mask_embeds, depth_embeds
        
# --------------------------------------
# Demo with batch_size = 2
# input: rgb_image features, depth_image features, masks
# output: rgb_region_emb, depth_region_emb
Fimg_rgb = torch.randn(2, 729, 1152)    # (B, L, F) output of visual encoder
Fimg_depth = torch.randn(2, 729, 1152)  # (B, L, F) output of visual encoder
# List of length B; item i has shape (n_masks_i, H, W).
# Masks are non-negative weights (torch.rand, values in [0, 1)). MaskPooling divides by
# each mask's sum, so signed torch.randn values could make that denominator ~0 or negative.
masks = [
    torch.rand(5, 384, 384),  # 5 masks of sample 1
    torch.rand(3, 384, 384),  # 3 masks of sample 2
]

# region_extractor model
region_extractor_cfg = RegionExtractorConfig(region_extractor_type="regiongpt")
config = PretrainedConfig()
config.mm_hidden_size = 1152
config.hidden_size = 4096

region_extractor = RegionExtractor(region_extractor_cfg, config)

with torch.no_grad():  # inference-only demo: skip autograd bookkeeping
    # high- and low-resolution features
    hres_tower_features, lres_tower_features = region_extractor.feature_refinement(Fimg_rgb)
    # hres_tower_features: (B, 11664, 1152)
    # lres_tower_features: (B, 729, 1152)

    # mask_pooling + projector
    reg_rgb_emb, reg_depth_emb = region_extractor(hres_tower_features, Fimg_depth, masks)
    # reg_rgb_emb: list(tensor[n_masks_i, hidden_size])
    # reg_depth_emb: list(tensor[n_masks_i, hidden_size])

# Printed as (B, shape of sample 1, shape of sample 2)
print(f"For rgb region:   {len(reg_rgb_emb), reg_rgb_emb[0].shape, reg_rgb_emb[1].shape}")
print(f"For depth region: {len(reg_depth_emb), reg_depth_emb[0].shape, reg_depth_emb[1].shape}")


For rgb region:   (2, torch.Size([5, 4096]), torch.Size([3, 4096]))
For depth region: (2, torch.Size([5, 4096]), torch.Size([3, 4096]))


## Modify RegionExtractor to add RegionTransformer

In [87]:
# class RegionExtractor(PreTrainedModel):
#     config_class = RegionExtractorConfig

#     def __init__(self, region_extractor_cfg: RegionExtractorConfig, config: PretrainedConfig):
#         super().__init__(region_extractor_cfg)
#         region_extractor_type = region_extractor_cfg.region_extractor_type

#         if region_extractor_type == "regiongpt":
#             self.mask_pooling = MaskPooling()
#             self.feature_refinement_module = get_feature_refinement_module(config.mm_hidden_size)
#             # TODO: hardcoded pooling size here, should be inside cfg
#             self.ada_pooling = nn.AdaptiveAvgPool2d(27)
#             self.rgb_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
#             self.depth_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
            
#             # ============= Add Region Transformer ============
#             self.use_region_transformer = True
#             self.Region_Transformer = RegionTransformer(d_model=1152, nhead=8, num_self_layers=6, num_cross_layers=4, max_regions=15, max_tokens_img=800)
#             # =================================================
            
#         elif region_extractor_type == "duplicate":
#             self.mask_pooling = MaskPooling()
#             self.rgb_projector = None
#             self.depth_projector = None
#         elif region_extractor_type == "duplicate_deconv":
#             self.feature_refinement_module = get_feature_refinement_module(config.mm_hidden_size)
#             self.ada_pooling = nn.AdaptiveAvgPool2d(27)
#             self.mask_pooling = MaskPooling()
#             self.rgb_projector = None
#             self.depth_projector = None

#     def feature_refinement(self, tower_features):
#         HW = tower_features.shape[1]
#         tower_features = einops.rearrange(tower_features, "N (H W) C -> N C H W", H=int(HW**0.5))
#         hres_tower_features = self.feature_refinement_module(tower_features)
#         # local feature branch
#         hres_tower_features_flatten = einops.rearrange(hres_tower_features, "N C H W -> N (H W) C")

#         # global feature branch
#         ada_image_feature = self.ada_pooling(hres_tower_features)
#         lres_tower_features_flatten = einops.rearrange(ada_image_feature, "N C H W -> N (H W) C")
#         return hres_tower_features_flatten, lres_tower_features_flatten

#     def extract_region_features(self, hres_tower_features, masks, connector):
#         # assume is already flattened -> 'N (H W) C'
#         if self.config.region_extractor_type == "regiongpt":
#             mask_embeds = self.mask_pooling(hres_tower_features, masks, return_list=True)
#             _mask_embeds = []
#             for mask_embed in mask_embeds:
#                 if mask_embed is None:
#                     _mask_embeds.append(None)
#                 else:
#                     # ----------- modify -----------
#                     if not self.use_region_transformer:
#                         _mask_embeds.append(connector(mask_embed))
#                     else:
#                         _mask_embeds.append(mask_embed) # remove connector, later
#                     # ------------------------------

#         elif self.config.region_extractor_type in ["duplicate", "duplicate_deconv"]:
#             raise NotImplementedError(f"{self.config.region_extractor_type} not implemented")

#         mask_embeds = _mask_embeds

#         return mask_embeds

#     def forward(self,
#                 lres_tower_features, # add for region cross_decoder
#                 hres_tower_features, 
#                 depth_features,
#                 masks,
#                 *args, **kwargs
#                ):
        
#         Freg_rgbs = self.extract_region_features(hres_tower_features, masks, self.rgb_projector)      # rgb_region_emb list( tensor[n_masks, 1152] )
#         if depth_features is not None:
#             Freg_depths = self.extract_region_features(depth_features, masks, self.depth_projector)   # depth_region_emb list( tensor[n_masks, 1152] )
#         else:
#             Freg_depths = None
#         if not self.use_region_transformer:
#             return (None, None, None, Freg_rgbs, Freg_depths)
            
#         else:
#         # ========== Modify forward when using RegionTransformer ==========
#             CLS = []
#             Freg_rgb_trans = []
#             Freg_depth_trans = []
#             Freg_rgb_tran_projs = []
#             Freg_depth_tran_projs = []
                
#             for Freg_rgb, Freg_depth, lres_tower_feature in zip(Freg_rgbs, Freg_depths, lres_tower_features):
#                 # print(f"Freg_rgb: {Freg_rgb[None].shape}")
#                 # print(f"Freg_depth: {Freg_depth[None].shape}")
#                 # print(f"lres_tower_feature: {lres_tower_feature[None].shape}")
#                 Cls, Freg_rgb_tran, Freg_depth_tran = self.Region_Transformer(Freg_rgb[None], Freg_depth[None], lres_tower_feature[None])
                
#                 Freg_rgb_tran_proj = self.rgb_projector(Freg_rgb_tran)
#                 Freg_depth_tran_proj = self.depth_projector(Freg_depth_tran)

#                 CLS.append(Cls)
#                 Freg_rgb_trans.append(Freg_rgb_tran)
#                 Freg_depth_trans.append(Freg_depth_tran)
#                 Freg_rgb_tran_projs.append(Freg_rgb_tran_proj)
#                 Freg_depth_tran_projs.append(Freg_depth_tran_proj)
            
#             # return mask_embeds, depth_embeds
#             return (torch.stack(CLS),         # CLS embed for each sample
#                     Freg_rgb_trans,           # rgb region embeds before rgb projector
#                     Freg_depth_trans,         # depth region embeds before rgb projector
#                     Freg_rgb_tran_projs,      # rgb region embeds after rgb projector
#                     Freg_depth_tran_projs)    # depth region embeds after rgb projector

# --------------------------------------
# Demo with batch_size = 2
# input: rgb_image features, depth_image features, masks
# output: rgb_region_emb, depth_region_emb
Fimg_rgb = torch.randn(2, 729, 1152)    # (B, L, F) output of visual encoder
Fimg_depth = torch.randn(2, 729, 1152)  # (B, L, F) output of visual encoder
# List of length B; item i has shape (n_masks_i, H, W). Non-negative weights
# (torch.rand) keep MaskPooling's sum-normalisation well defined.
masks = [
    torch.rand(5, 384, 384),  # 5 masks of sample 1
    torch.rand(3, 384, 384),  # 3 masks of sample 2
]

# # Config model
# region_extractor_cfg = RegionExtractorConfig(region_extractor_type="regiongpt")
# config = PretrainedConfig()
# config.mm_hidden_size = 1152
# config.hidden_size = 4096

# # Init model
# region_extractor = RegionExtractor(region_extractor_cfg, config) 

# # output
# hres_tower_features, lres_tower_features = region_extractor.feature_refinement(Fimg_rgb)       # get high features and low features
# print(f"hres_tower_features: {hres_tower_features.shape}")
# print(f"lres_tower_features: {lres_tower_features.shape}")
# # hres_tower_features: (B, 11664, 1152)
# # lres_tower_features: (B, 729, 1152)

# # reg_rgb_emb, reg_depth_emb = region_extractor(hres_tower_features, Fimg_depth, masks) # original
# (CLS, 
#  Freg_rgb_trans,
#  Freg_depth_trans,
#  Freg_rgb_embs,
#  Freg_depth_embs) = region_extractor(lres_tower_features,   # add low feature for region cross_decoder
#                                      hres_tower_features,
#                                      Fimg_depth,
#                                      masks) 
# print(f"CLS: {CLS.shape}")
# print(f"Freg_rgb_trans: len {len(Freg_rgb_trans)} - {Freg_rgb_trans[0].shape} - {Freg_rgb_trans[1].shape}")         # output of RegionTransformer for rgb region
# print(f"Freg_rgb_trans: len {len(Freg_depth_trans)} - {Freg_depth_trans[0].shape} - {Freg_depth_trans[1].shape}")   # output of RegionTransformer for depth region
# print(f"Freg_rgb_trans_proj: len {len(Freg_rgb_embs)} - {Freg_rgb_embs[0].shape} - {Freg_rgb_embs[1].shape}")       # output of rgb_projector for rgb region
# print(f"Freg_rgb_trans_proj: len {len(Freg_depth_embs)} - {Freg_depth_embs[0].shape} - {Freg_depth_embs[1].shape}") # output of depth_projector for rgb region


hres_tower_features: torch.Size([2, 11664, 1152])
lres_tower_features: torch.Size([2, 729, 1152])
CLS: torch.Size([2, 1, 1152])
Freg_rgb_trans: len 2 - torch.Size([1, 5, 1152]) - torch.Size([1, 3, 1152])
Freg_rgb_trans: len 2 - torch.Size([1, 5, 1152]) - torch.Size([1, 3, 1152])
Freg_rgb_trans_proj: len 2 - torch.Size([1, 5, 4096]) - torch.Size([1, 3, 4096])
Freg_rgb_trans_proj: len 2 - torch.Size([1, 5, 4096]) - torch.Size([1, 3, 4096])


In [2]:
# Flow in prepare_inputs_labels_for_multimodal(), simulated on random tensors.
# torch / PretrainedConfig are imported here too so this cell also runs on a fresh kernel.
import torch
from transformers import PretrainedConfig

from llava.model.region_extractor.base_extractor import RegionExtractor, RegionExtractorConfig
from llava.model.region_transformer import RegionFeatureExtractor
from llava.model.region_heads import RegionClassifier, DistanceHead, MultipleChoiceHead, CountingHead, LeftRightHead

torch.manual_seed(0)  # reproducible random inputs and module initialisation

print("\n--- 1. Defining Inputs & Config ---")
# --- Simulation Parameters ---
BATCH_SIZE = 2
NUM_MASKS_SAMPLE_1 = 5
NUM_MASKS_SAMPLE_2 = 3
TOTAL_MASKS = NUM_MASKS_SAMPLE_1 + NUM_MASKS_SAMPLE_2
GLOBAL_IMG_PATCHES = 729 # (e.g., 27x27 grid from ada_pooling in original RegionExtractor)
MM_HIDDEN_SIZE = 1152 # Feature dimension from Vision Tower / RegionExtractor
LLM_HIDDEN_SIZE = 4096 # Feature dimension for the LLM

# --- Simulated Tensors ---
# These represent the outputs from the Vision Tower
Fimg_rgb_patches = torch.randn(BATCH_SIZE, GLOBAL_IMG_PATCHES, MM_HIDDEN_SIZE)
Fimg_depth_patches = torch.randn(BATCH_SIZE, GLOBAL_IMG_PATCHES, MM_HIDDEN_SIZE)

# Stand-in for the list of masks from the dataloader (random soft values in [0, 1))
masks_list = [
    torch.rand(NUM_MASKS_SAMPLE_1, 384, 384), # 5 masks for sample 1
    torch.rand(NUM_MASKS_SAMPLE_2, 384, 384)  # 3 masks for sample 2
]

print(f"Simulated Fimg_rgb (global patches): {Fimg_rgb_patches.shape}")
print(f"Simulated Fimg_depth (global patches): {Fimg_depth_patches.shape}")
print(f"Simulated masks_list: {len(masks_list)} items, first item shape: {masks_list[0].shape}, \n\t\t\t\tsecond item shape: {masks_list[1].shape}")


# --- Model Configurations ---
region_extractor_cfg = RegionExtractorConfig(region_extractor_type="regiongpt")
config = PretrainedConfig()
config.mm_hidden_size = MM_HIDDEN_SIZE
config.hidden_size = LLM_HIDDEN_SIZE # LLM dimension

print("\n--- 2. Instantiate Modules ---")
# --- Instantiate Original and New Modules ---
# This is the original module responsible for pooling features from masks
region_extractor = RegionExtractor(region_extractor_cfg, config).eval()

region_feature_extractor_new = RegionFeatureExtractor(
    dim=MM_HIDDEN_SIZE, # Operates on 1152-dim features
    num_heads=8,
    num_transformer_layers=6,
    num_cross_attn_layers=1
).eval()

# initialize heads
# RegionHead classifies each region
num_region_classes = 10 # Example: pallet, shelf, transporter, etc.
max_object_count = 15   # Example: max number of objects to count in a scene
region_head = RegionClassifier(
    infeatures=MM_HIDDEN_SIZE, # Takes the 1152-dim features
    nclasses=num_region_classes,
).eval()

# DistanceHead measures the distance between 2 regions
distance_head = DistanceHead(
    infeatures = MM_HIDDEN_SIZE
).eval()

# LeftRightHead classifies the left/right relation of 2 regions
leftright_head = LeftRightHead(
    infeatures=MM_HIDDEN_SIZE
).eval()

multiplechoice_head = MultipleChoiceHead(
    infeatures=MM_HIDDEN_SIZE
).eval()

counting_head = CountingHead(
    infeatures=MM_HIDDEN_SIZE
).eval()
print("Modules instantiated successfully.")

print("\n--- 3. Simulating Forward Pass Data Flow ---")
print("\n--- Step A: Feature Pooling (Simulating RegionExtractor) ---")
hres_tower_features_rgb, lres_tower_features_rgb = region_extractor.feature_refinement(Fimg_rgb_patches)
print(f"hres_tower_features_rgb (for RGB mask pooling): {hres_tower_features_rgb.shape}")
print(f"lres_tower_features_rgb (for global projector): {lres_tower_features_rgb.shape}")
print(f"Fimg_depth_patches (for depth mask pooling): {Fimg_depth_patches.shape}")


unprojected_rgb_regions_list = region_extractor.mask_pooling(hres_tower_features_rgb, masks_list, return_list=True)
unprojected_depth_regions_list = region_extractor.mask_pooling(Fimg_depth_patches, masks_list, return_list=True)
print(f"Unprojected RGB region features (sample 0 - 1): {unprojected_rgb_regions_list[0].shape, unprojected_rgb_regions_list[1].shape}")
print(f"Unprojected Depth region features (sample 0 - 1): {unprojected_depth_regions_list[0].shape, unprojected_depth_regions_list[1].shape}")
# Each sample yields [num_masks_of_sample, MM_HIDDEN_SIZE] for both modalities.
expected_pooled_shapes = [(NUM_MASKS_SAMPLE_1, MM_HIDDEN_SIZE), (NUM_MASKS_SAMPLE_2, MM_HIDDEN_SIZE)]
for pooled_list in (unprojected_rgb_regions_list, unprojected_depth_regions_list):
    assert [tuple(t.shape) for t in pooled_list] == expected_pooled_shapes

print("\n--- Step B: Region Interaction (RegionFeatureExtractor) ---")
# For batch processing, the per-sample lists would need padding and stacking.
# Test with the first sample for simplicity.
unprojected_rgb_regions_s0 = unprojected_rgb_regions_list[0]
unprojected_depth_regions_s0 = unprojected_depth_regions_list[0]
num_masks_s0 = unprojected_rgb_regions_s0.shape[0]

global_context_features_s0 = lres_tower_features_rgb[0]


# module enhances region features using self- and cross-attention
enhanced_region_features_s0 = region_feature_extractor_new(
    rgb_features=unprojected_rgb_regions_s0,
    depth_features=unprojected_depth_regions_s0,
    image_features=global_context_features_s0
)
print(f"Enhanced region features (sample 0): {enhanced_region_features_s0.shape}")
assert enhanced_region_features_s0.shape == (2 * num_masks_s0, MM_HIDDEN_SIZE)
print("\n--- Step C: Branching for LLM and Auxiliary Heads ---")

print("\n  --- Branch 1 (LLM Pathway) ---")
# Split the enhanced features back into RGB and Depth.
# Assumes RegionFeatureExtractor returns RGB rows first, then depth rows — TODO confirm.
enhanced_rgb_s0 = enhanced_region_features_s0[:num_masks_s0]
enhanced_depth_s0 = enhanced_region_features_s0[num_masks_s0:]
print(f"Enhanced RGB region features (sample 0): {enhanced_rgb_s0.shape}") # Should be [NUM_MASKS_SAMPLE_1, 1152]
print(f"Enhanced Depth region features (sample 0): {enhanced_depth_s0.shape}") # Should be [NUM_MASKS_SAMPLE_1, 1152]


# Pass through the original projectors to get LLM-compatible embeddings
projected_rgb_for_llm_s0 = region_extractor.rgb_projector(enhanced_rgb_s0)
projected_depth_for_llm_s0 = region_extractor.depth_projector(enhanced_depth_s0)
print(f"Projected RGB for LLM (sample 0): {projected_rgb_for_llm_s0.shape}") # Should be [num_masks, 4096]
print(f"Projected Depth for LLM (sample 0): {projected_depth_for_llm_s0.shape}") # Should be [num_masks, 4096]

# --- Branch 2: Features for the Auxiliary Task Heads ---
print("\n  --- Branch 2 (Auxiliary Heads) ---")
# Each auxiliary head is tested individually on the enhanced features.

# 2.1 - Region Classifier: classifies every region row of the image
region_class_logits = region_head(enhanced_region_features_s0)
print(f"Region Classifier inputs (sample 0): {enhanced_region_features_s0.shape}") # [num_masks*2, 1152] (RGB + depth rows)
# NOTE(review): the input has 2 * num_masks rows; confirm whether RegionClassifier should
# return [num_masks, num_region_classes] or one row per input row.
print(f"Region Classifier logits (sample 0): {region_class_logits.shape}")

# 2.2.1 - Distance Head: takes 2 regions and predicts the distance between them.
# The first two masks serve as the pair: (RGB features, Depth features), each of shape (2, F).
two_region_features = (enhanced_rgb_s0[:2], enhanced_depth_s0[:2])
distance_prediction = distance_head(two_region_features)
print(f"Distance Head inputs (sample 0): tuple of {len(two_region_features)} - for RGB: {two_region_features[0].shape} - for Depth: {two_region_features[1].shape}") # tuple of (rgb_features, depth_features) each of shape (2, F)
print(f"Distance Head prediction (sample 0): {distance_prediction}") # Expected: a single scalar distance


# 2.2.2 - Left/Right Head: takes 2 regions and predicts their left/right relation
left_right_logits = leftright_head(two_region_features)
print(f"Left/Right Head inputs (sample 0): tuple of {len(two_region_features)} - for RGB: {two_region_features[0].shape} - for Depth: {two_region_features[1].shape}") # tuple of (rgb_features, depth_features) each of shape (2, F)
print(f"Left/Right Head logits (sample 0): {left_right_logits}") # Expected: [1, 2] (logits for left/right classes)


# 2.2.3 - MultipleChoice Head: takes features for all choices and predicts which one is correct.
all_region_features = (enhanced_rgb_s0, enhanced_depth_s0)
mcq_logits = multiplechoice_head(all_region_features)
print(f"MultipleChoice Head inputs (sample 0): tuple of {len(all_region_features)} - for RGB: {all_region_features[0].shape} - for Depth: {all_region_features[1].shape}") # tuple of (rgb_features, depth_features) each of shape (Nr, F)
print(f"MultipleChoice Head logits (sample 0): {mcq_logits}") # Expected: [num_masks]

# 2.2.4 - Counting Head: fuses all region features and predicts a count.
count_logits = counting_head(all_region_features)
print(f"Counting Head inputs (sample 0): tuple of {len(all_region_features)} - for RGB: {all_region_features[0].shape} - for Depth: {all_region_features[1].shape}") # tuple of (rgb_features, depth_features) each of shape (Nr, F)
print(f"Counting Head logits (sample 0): {count_logits}")

print("\n--- Data Flow Test Complete ---")


--- 1. Defining Inputs & Config ---
Simulated Fimg_rgb (global patches): torch.Size([2, 729, 1152])
Simulated Fimg_depth (global patches): torch.Size([2, 729, 1152])
Simulated masks_list: 2 items, first item shape: torch.Size([5, 384, 384]), 
				second item shape: torch.Size([3, 384, 384])

--- 2. Instantiate Modules ---
Modules instantiated successfully.

--- 3. Simulating Forward Pass Data Flow ---

--- Step A: Feature Pooling (Simulating RegionExtractor) ---
hres_tower_features_rgb (for RGB mask pooling): torch.Size([2, 11664, 1152])
lres_tower_features_rgb (for global projector): torch.Size([2, 729, 1152])
Fimg_depth_patches (for depth mask pooling): torch.Size([2, 729, 1152])
Unprojected RGB region features (sample 0 - 1): (torch.Size([5, 1152]), torch.Size([3, 1152]))
Unprojected Depth region features (sample 0): (torch.Size([5, 1152]), torch.Size([3, 1152]))

--- Step B: Region Interaction (RegionFeatureExtractor) ---
Enhanced region features (sample 0): torch.Size([10, 1152])

In [39]:
# RGB and depth rows of the first two masks; the depth rows start at num_masks_s0 (not a hard-coded 5).
enhanced_region_features_s0[[0, 1, num_masks_s0, num_masks_s0 + 1]].shape

torch.Size([4, 1152])

In [40]:
# Full enhanced-feature matrix: RGB rows followed by depth rows.
enhanced_region_features_s0.size()

torch.Size([10, 1152])