# Understanding Phi3ImageEmbedding Module (Phi-3-Vision)

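The `Phi3ImageEmbedding` module is where the Phi-3-Vision model merges visual and textual input into a single embedding sequence. It embeds the text token ids, encodes each image with a CLIP ViT-L/14-336 vision tower, projects the image features to the text hidden size, and writes them over the embeddings at the image-placeholder token positions. Let's break down its components and functionality: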

```python
class Phi3ImageEmbedding(nn.Module):
    """Embed text tokens and splice projected image features into the sequence.

    Text token ids are looked up in ``wte``. Images are encoded by a CLIP
    vision tower (``img_processor``), projected to ``config.hidden_size`` by
    ``img_projection`` (Linear -> GELU -> Linear), and written over the token
    embeddings at the image-placeholder ``positions``.

    Keyword arguments:
        use_hd_transform: enable the high-resolution (crop-grid) path.
        with_learnable_separator: stored; not read by this class.
        hd_transform_order: stored; not read by this class. The HD path always
            emits ``[sub-image features, glb_GN, global-image features]``.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.img_processor = VisionModel(CLIP_VIT_LARGE_PATCH14_336_CONFIG)
        self.num_img_tokens = config.img_processor['num_img_tokens']
        self.image_dim_out = image_dim_out = config.img_processor['image_dim_out']
        self.img_sizes = None
        self.use_hd_transform = kwargs.get('use_hd_transform', False)
        self.with_learnable_separator = kwargs.get('with_learnable_separator', False)
        self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub')
        # Separator embeddings, zero-initialised here (presumably overwritten by
        # checkpoint weights). glb_GN separates the sub-image and global blocks;
        # sub_GN is appended at the end of every feature-map row.
        self.glb_GN = mx.zeros([1, 1, self.image_dim_out * 4])
        self.sub_GN = mx.zeros([1, 1, 1, self.image_dim_out * 4])
        # Input width is 4 * image_dim_out because the HD path merges each 2x2
        # patch neighbourhood into one feature vector.
        self.img_projection = [nn.Linear(image_dim_out * 4, config.hidden_size), nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)]
        self.vocab_size = config.vocab_size
        self.img_features = None
        self.layer_idx = config.img_processor.get('layer_idx', -2)
        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def get_img_features(self, img_embeds):
        """Run the vision tower and return patch features from ``self.layer_idx``.

        ``img_embeds`` is transposed with axes ``(0, 2, 3, 1)`` before the
        vision tower. It is presumably (N, C, H, W) pixels and the tower expects
        channels-last; TODO confirm against VisionModel.

        The first token of the selected hidden state is dropped (presumably the
        CLS token). NOTE(review): ``self.type_feature`` is not consulted; the
        patch-only slice is always returned.
        """
        img_processor_output = self.img_processor(
            img_embeds.transpose(0, 2, 3, 1), output_hidden_states=True
        )
        # The last element of the output is indexed as per-layer hidden states.
        # layer_idx defaults to -2 (second-to-last layer).
        img_feature = img_processor_output[-1][self.layer_idx]
        return img_feature[:, 1:]

    def _project(self, x):
        """Apply the ``img_projection`` layers to ``x`` in order."""
        for layer in self.img_projection:
            x = layer(x)
        return x

    def _hd_image_features(self, img_embeds, img_sizes):
        """HD path: build one projected feature sequence per image.

        Returns ``(features, lengths)``. ``features[i]`` is the projected
        sequence for image ``i`` with shape ``(1, lengths[i], hidden_size)``.
        """
        bs = img_embeds.shape[0]
        # Fold the crop axis into the batch so every crop is encoded at once.
        img_features = self.get_img_features(
            img_embeds.reshape(-1, *img_embeds.shape[2:])
        )
        # Each crop yields a square H x H patch grid.
        H = int(img_features.shape[1] ** 0.5)
        C = self.image_dim_out
        half = H // 2
        img_features = img_features.reshape(bs, -1, H * H, C)

        features = []
        lengths = []
        for b in range(bs):
            h, w = img_sizes[b].tolist()
            # Number of 336-pixel crops along each axis.
            h = h // 336
            w = w // 336
            num_crops = h * w

            # Global view (crop 0). Merge each 2x2 patch neighbourhood into one
            # 4*C vector, then append a sub_GN column at the end of every row.
            glb_img = (
                img_features[b, :1]
                .reshape(1, H, H, C)
                .reshape(1, half, 2, half, 2, C)
                .transpose(0, 1, 3, 2, 4, 5)
                .reshape(1, half, half, 4 * C)
            )
            glb_img = mx.concatenate(
                [glb_img, mx.tile(self.sub_GN, (1, half, 1, 1))], axis=2
            ).reshape(1, -1, 4 * C)

            # Local crops: apply the same 2x2 merge, then stitch the h x w crop
            # grid into one (h*half) x (w*half) map. A sub_GN column follows
            # each row.
            sub_img = img_features[b, 1:][:num_crops]
            sub_img = (
                sub_img.reshape(num_crops, H, H, C)
                .reshape(num_crops, half, 2, half, 2, C)
                .transpose(0, 1, 3, 2, 4, 5)
                .reshape(num_crops, -1, 4 * C)
            )
            sub_img = (
                sub_img.reshape(1, h, w, half, half, -1)
                .transpose(0, 1, 3, 2, 4, 5)
                .reshape(1, h * half, w * half, 4 * C)
            )
            sub_img = mx.concatenate(
                [sub_img, mx.tile(self.sub_GN, (1, h * half, 1, 1))], axis=2
            ).reshape(1, -1, 4 * C)

            # NOTE(review): the order is always sub -> glb_GN -> global,
            # regardless of self.hd_transform_order.
            merged = mx.concatenate([sub_img, self.glb_GN, glb_img], axis=1)
            features.append(self._project(merged))
            # Equals (h*w + 1) * half**2 + 1 + (h + 1) * half. It is read from
            # the array so it cannot drift from the layout built above.
            lengths.append(merged.shape[1])
        return features, lengths

    def __call__(self, input_ids, img_embeds, img_sizes, positions):
        """Return token embeddings with image features written in place.

        Args:
            input_ids: token id array. Ids are clipped into
                ``[0, vocab_size - 1]`` before the ``wte`` lookup; the spans
                starting at ``positions`` are then overwritten with image
                features.
            img_embeds: pixel values. The HD path expects a 5-D
                (batch, crops, ...) array; the other paths accept 4-D or 3-D.
            img_sizes: per-image (height, width) in pixels, used by the HD path.
            positions: 2-D array whose rows are (batch row, sequence column)
                of the image-placeholder tokens.

        Returns:
            Embeddings of ``input_ids`` with image features written over the
            placeholder spans.

        Raises:
            NotImplementedError: for non-HD input whose ``ndim`` is not 3 or 4.
        """
        # Convert once: the HD loop previously called .tolist() on every
        # iteration, which failed from the second image onward.
        pos_list = positions.tolist()
        # Previously only bound inside the HD branch, so the non-HD paths
        # raised UnboundLocalError below.
        hd_transform = False
        img_set_tensor = None
        num_img_tokens = None

        if pos_list:
            if self.use_hd_transform and img_sizes is not None and len(img_sizes):
                hd_transform = True
                img_set_tensor, num_img_tokens = self._hd_image_features(
                    img_embeds, img_sizes
                )
            # NOTE(review): the non-HD paths feed image_dim_out-wide features
            # into img_projection, whose first Linear expects 4 * image_dim_out.
            elif img_embeds.ndim == 4:
                x = self.get_img_features(img_embeds).reshape(-1, self.image_dim_out)
                img_set_tensor = self._project(x)
            elif img_embeds.ndim == 3:
                # mx.array.view reinterprets the dtype; reshape is what's meant.
                x = img_embeds.reshape(-1, self.image_dim_out)
                img_set_tensor = self._project(x)
            else:
                raise NotImplementedError(
                    f"unsupported img_embeds.ndim={img_embeds.ndim}"
                )

        # vocab_size - 1 is the last valid wte row; vocab_size is out of range.
        input_ids = mx.clip(input_ids, 0, self.vocab_size - 1)
        hidden_states = self.wte(input_ids)

        if not pos_list:
            return hidden_states

        if hd_transform:
            # Each image occupies a contiguous span of cnt placeholder tokens.
            idx = 0
            for feats, cnt in zip(img_set_tensor, num_img_tokens):
                row, col = pos_list[idx]
                hidden_states[row, col : col + cnt] = feats
                idx += cnt
        else:
            # Fixed-size spans: image i starts at placeholder index i * cnt.
            cnt = self.num_img_tokens
            for i, idx in enumerate(range(0, len(pos_list), cnt)):
                row, col = pos_list[idx]
                hidden_states[row, col : col + cnt] = (
                    img_set_tensor[i * cnt : (i + 1) * cnt]
                )
        return hidden_states
```