# **LayoutLM V2**


### Objectives for pre-training :-

Two new pre-training objectives compared to V1, in addition to Masked Visual-Language Modeling (MVLM), which is kept from V1:-

1. Text-image alignment (TIA)
    - Some token lines are randomly selected, and their image regions are covered on the document image.
    - A classification layer is built above the encoder outputs. It predicts a label for each text token depending on whether its image region is covered or not ([Covered] / [Not Covered]).

2. Text-image matching (TIM)
    - The output representation at [CLS] is fed into a classifier that predicts whether the image and text come from the same document page.
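The two new objectives can be sketched at the data level. This is a minimal illustration, not the paper's pipeline: it assumes every text token already carries the id of the OCR line it came from, and the helper names (`tia_targets`, `tim_pair`), the 15% default cover ratio and the 1/0 label encoding are all hypothetical. TIA covers whole lines and labels every token of a covered line; TIM pairs the text with either its own page image (label 1) or another page's image (label 0).

```python
import random
from typing import List, Sequence, Set, Tuple

COVERED, NOT_COVERED = 1, 0  # hypothetical label encoding for TIA


def tia_targets(
    token_line_ids: Sequence[int], cover_ratio: float = 0.15, seed: int = 0
) -> Tuple[Set[int], List[int]]:
    """Text-image alignment: cover whole OCR lines and label each token.

    cover_ratio is an illustrative assumption, not a value from the paper.
    Returns the ids of the covered lines (whose image regions would be
    painted over) and one COVERED/NOT_COVERED label per text token.
    """
    rng = random.Random(seed)
    lines = sorted(set(token_line_ids))
    if not lines:
        return set(), []
    n_cover = min(len(lines), max(1, round(len(lines) * cover_ratio)))
    covered_lines = set(rng.sample(lines, n_cover))
    labels = [COVERED if lid in covered_lines else NOT_COVERED for lid in token_line_ids]
    return covered_lines, labels


def tim_pair(
    page_id: int, all_page_ids: Sequence[int], negative_prob: float = 0.5, seed: int = 0
) -> Tuple[int, int]:
    """Text-image matching: choose the page image to pair with the text.

    Returns (image_page_id, label): label 1 if the image comes from the same
    page as the text, 0 if it was swapped for another page's image.
    """
    rng = random.Random(seed)
    others = [p for p in all_page_ids if p != page_id]
    if others and rng.random() < negative_prob:
        return rng.choice(others), 0
    return page_id, 1


covered, labels = tia_targets([0, 0, 1, 1, 1, 2], cover_ratio=0.34, seed=1)
print(covered, labels)
print(tim_pair(5, [5, 6, 7], negative_prob=1.0))
```

Covering at the line level (rather than per word) means all tokens of a line share one label, which the sketch preserves.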



### Components of V2:- 

1. Text embedding:- 

        ti = TokEmb(wi) + PosEmb1D(i) + SegEmb(si)


2. Visual embedding:-

        vi = Proj(VisTokEmb(I)i) + PosEmb1D(i) + SegEmb([C])
        
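    Here, the document page image is resized to 224x224 and fed into the visual backbone; the output feature map is then average-pooled to a fixed size of width W and height H.
    The pooled map is flattened into a sequence of W*H visual tokens, and each token is passed through a linear projection layer (Proj) to match the text embedding dimension.

    A 1D positional embedding, shared with the text embedding layer, is added because the CNN-based visual encoder cannot capture positional information on its own.

    All visual tokens are assigned to the visual segment [C].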


3. Layout Embedding:- 

        li = Concat(PosEmb2Dx(xmin, xmax, width), PosEmb2Dy(ymin, ymax, height))

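    Two embedding layers encode the x-axis and y-axis features separately. Given the normalized bounding box of the i-th text/visual token (0 ≤ i < WH + L, where L is the number of text tokens and WH the number of visual tokens),

        boxi = (xmin, xmax, ymin, ymax, width, height)

    the x features (xmin, xmax, width) and the y features (ymin, ymax, height) are embedded separately and concatenated as in the formula above. Coordinates are normalized and discretized to integers in [0, 1000]; for a visual token, the box is the grid cell it was pooled from.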
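A minimal sketch of the layout embedding for one token, using tiny random lookup tables in place of learned `nn.Embedding` layers. Following the Hugging Face implementation (`_calc_spatial_position_embeddings`), the box is stored as (x0, y0, x1, y1); x0 and x1 share the x table, y0 and y1 share the y table, and width and height get their own tables. The class name, table contents and the `coord_dim`/`shape_dim` sizes are illustrative assumptions (in the default Hugging Face config, 4 × 128 + 2 × 128 = 768 = `hidden_size`).

```python
import random
from typing import List, Sequence

MAX_2D = 1024  # coordinates are integers in [0, 1000], so 1024 rows suffice


def make_table(rows: int, dim: int, seed: int) -> List[List[float]]:
    """Random stand-in for a learned embedding table."""
    rng = random.Random(seed)
    return [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(rows)]


class LayoutEmbedding:
    """Toy version of the four 2D embedding tables (untrained)."""

    def __init__(self, coord_dim: int = 4, shape_dim: int = 2) -> None:
        self.x = make_table(MAX_2D, coord_dim, seed=0)
        self.y = make_table(MAX_2D, coord_dim, seed=1)
        self.w = make_table(MAX_2D, shape_dim, seed=2)
        self.h = make_table(MAX_2D, shape_dim, seed=3)

    def __call__(self, box: Sequence[int]) -> List[float]:
        x0, y0, x1, y1 = box  # normalized to [0, 1000], assumes x1 >= x0, y1 >= y0
        width, height = x1 - x0, y1 - y0
        # Concatenate (not sum) the coordinate and size features; this order
        # follows the Hugging Face code, the paper groups x and y features.
        return (
            self.x[x0] + self.y[y0] + self.x[x1] + self.y[y1]
            + self.h[height] + self.w[width]
        )


emb = LayoutEmbedding(coord_dim=4, shape_dim=2)
vec = emb((100, 200, 300, 250))
print(len(vec))  # 20 = 4 * 4 + 2 * 2
```

Sharing one table between x0 and x1 (and between y0 and y1) means a given horizontal coordinate maps to the same vector whether it is a left or right edge; its role is encoded by its slot in the concatenation.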





### Spatial-Aware Self attention :- 

#### Relative positional embeddings
**The semantic (1D) relative position and the spatial (2D) relative position are modelled as bias terms added to the attention scores, which avoids adding too many parameters**

$$
\alpha'_{ij} = \alpha_{ij} + b^{(\mathrm{1D})}_{j-i} + b^{(\mathrm{2D}x)}_{x_j - x_i} + b^{(\mathrm{2D}y)}_{y_j - y_i}
$$

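Here α_ij is the original scaled dot-product attention score between tokens i and j, (x_i, y_i) are coordinates taken from the bounding box of token i, and the three bias terms b are learnable.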

# **Flow :-**


#### Initialising the image embeddings and other layers:- 

    self.has_visual_segment_embedding = config.has_visual_segment_embedding
    self.embeddings = LayoutLMv2Embeddings(config)

    self.visual = LayoutLMv2VisualBackbone(config)

    self.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size)

    if self.has_visual_segment_embedding:
        self.visual_segment_embedding = nn.Parameter(nn.Embedding(1, config.hidden_size).weight[0])
    
    self.visual_LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.visual_dropout = nn.Dropout(config.hidden_dropout_prob)

    self.encoder = LayoutLMv2Encoder(config)
    self.pooler = LayoutLMv2Pooler(config)


## 1. Image processing:- 
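Tesseract OCR is applied to the document image.

- Words and their bounding boxes are obtained from this step; the boxes are normalised to a 0-1000 scale relative to the page width and height.
- By default, the image is also resized to 224x224 and its color channels are flipped from RGB to BGR, the channel order the detectron2 backbone expects.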
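The 0-1000 normalization is a plain rescaling by the page size. A minimal standalone sketch, assuming pixel boxes in (x0, y0, x1, y1) order; it mirrors the per-coordinate `int(1000 * coordinate / page_size)` scaling applied to the OCR boxes in the Hugging Face LayoutLMv2 preprocessing, but is not that library's code:

```python
from typing import List, Sequence


def normalize_box(box: Sequence[float], width: int, height: int) -> List[int]:
    """Scale a pixel box (x0, y0, x1, y1) to integers in [0, 1000]."""
    x0, y0, x1, y1 = box
    return [
        int(1000 * (x0 / width)),   # x values are divided by the page width
        int(1000 * (y0 / height)),  # y values are divided by the page height
        int(1000 * (x1 / width)),
        int(1000 * (y1 / height)),
    ]


# A word box on a 640 x 800 pixel page.
print(normalize_box([160, 100, 480, 200], width=640, height=800))  # [250, 125, 750, 250]
```

Because the boxes are resolution-independent after this step, the same layout embedding tables work for pages scanned at any size.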


### **Here, LayoutLMv2VisualBackbone is a detectron2-based (ResNeXt-FPN) model that extracts the visual features from the document image**

### Image features and embeddings 

#### Obtaining the image embeddings:- 

    def _calc_img_embeddings(self, image, bbox, position_ids):

        visual_embeddings = self.visual_proj(self.visual(image))

        position_embeddings = self.embeddings.position_embeddings(position_ids)

        spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox)

        embeddings = visual_embeddings + position_embeddings + spatial_position_embeddings
        
        if self.has_visual_segment_embedding:
            embeddings += self.visual_segment_embedding

        embeddings = self.visual_LayerNorm(embeddings)
        embeddings = self.visual_dropout(embeddings)
        
        return embeddings
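Inside `LayoutLMv2VisualBackbone`, the backbone feature map is average-pooled to `image_feature_pool_shape[:2]` (7×7 by default) and flattened row by row into W*H visual tokens before `visual_proj` maps each 256-dimensional token to `hidden_size`. The pure-Python sketch below shows that pooling-and-flattening step on a single-channel map. It assumes PyTorch-style adaptive average pooling (bins from `floor(i * in / out)` to `ceil((i + 1) * in / out)`); the function names are illustrative, not library APIs.

```python
from typing import List

Grid = List[List[float]]


def adaptive_avg_pool2d(fmap: Grid, out_h: int, out_w: int) -> Grid:
    """Average-pool an in_h x in_w map to out_h x out_w (PyTorch-style bins)."""
    in_h, in_w = len(fmap), len(fmap[0])
    pooled = []
    for i in range(out_h):
        r0 = (i * in_h) // out_h
        r1 = -(-((i + 1) * in_h) // out_h)  # integer ceiling division
        row = []
        for j in range(out_w):
            c0 = (j * in_w) // out_w
            c1 = -(-((j + 1) * in_w) // out_w)
            cells = [fmap[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            row.append(sum(cells) / len(cells))
        pooled.append(row)
    return pooled


def flatten_to_tokens(pooled: Grid) -> List[float]:
    """Row-major flattening: token k = r * W + c, giving W*H visual tokens."""
    return [v for row in pooled for v in row]


fmap = [[float(4 * r + c) for c in range(4)] for r in range(4)]
print(flatten_to_tokens(adaptive_avg_pool2d(fmap, 2, 2)))  # [2.5, 4.5, 10.5, 12.5]
```

Pooling to a fixed grid keeps the number of visual tokens constant (49 for 7×7) regardless of the backbone's feature-map size, so the joint sequence length is always L + WH.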



## 2.  Textual processing:- 
The words and boxes are passed to **LayoutLMv2TokenizerFast**, which turns them into token-level **input_ids**,
**attention_mask**, **token_type_ids** and **bbox**. Each sub-word token inherits the bounding box of the word it was split from.
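A sketch of the word-to-token box expansion, assuming a toy tokenizer: `toy_wordpiece` is hypothetical and just cuts words into chunks of at most four characters instead of using a real WordPiece vocabulary. The special-token boxes are parameters; the defaults here ([0, 0, 0, 0] for [CLS], [1000, 1000, 1000, 1000] for [SEP]) are assumed to match the Hugging Face tokenizer's `cls_token_box` and `sep_token_box` defaults.

```python
from typing import List, Sequence, Tuple

Box = Tuple[int, int, int, int]


def toy_wordpiece(word: str, size: int = 4) -> List[str]:
    """Hypothetical tokenizer: split a word into chunks of at most `size` chars."""
    if not word:
        return []
    pieces = [word[i : i + size] for i in range(0, len(word), size)]
    return [pieces[0]] + ["##" + p for p in pieces[1:]]


def tokens_with_boxes(
    words: Sequence[str],
    boxes: Sequence[Box],
    cls_box: Box = (0, 0, 0, 0),
    sep_box: Box = (1000, 1000, 1000, 1000),
) -> Tuple[List[str], List[Box]]:
    """Expand word-level boxes to token level: each sub-word repeats its word's box."""
    tokens, token_boxes = ["[CLS]"], [cls_box]
    for word, box in zip(words, boxes):
        for piece in toy_wordpiece(word):
            tokens.append(piece)
            token_boxes.append(tuple(box))
    tokens.append("[SEP]")
    token_boxes.append(sep_box)
    return tokens, token_boxes


toks, bxs = tokens_with_boxes(["Invoice", "No"], [(100, 50, 250, 70), (260, 50, 300, 70)])
print(toks)  # ['[CLS]', 'Invo', '##ice', 'No', '[SEP]']
```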

### Text embeddings and features

#### Obtaining text embeddings:-
    def _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids, inputs_embeds=None):
        # Look up the word embeddings unless pre-computed embeddings were passed in
        if inputs_embeds is None:
            inputs_embeds = self.embeddings.word_embeddings(input_ids)

        position_embeddings = self.embeddings.position_embeddings(position_ids)

        spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox)

        token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)

        # TokEmb(w_i) + PosEmb1D(i) + SegEmb(s_i), plus the 2D layout embedding
        embeddings = inputs_embeds + position_embeddings + spatial_position_embeddings + token_type_embeddings

        embeddings = self.embeddings.LayerNorm(embeddings)
        embeddings = self.embeddings.dropout(embeddings)

        return embeddings

## 3. Preparing the encoder inputs:-

    visual_bbox = self._calc_visual_bbox(self.config.image_feature_pool_shape, bbox, device, final_shape)


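The visual boxes tile the normalized 0-1000 page with a grid of `image_feature_pool_shape[1]` columns and `image_feature_pool_shape[0]` rows, one box per visual token, computed as follows:-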

    def _calc_visual_bbox(self, image_feature_pool_shape, bbox, device, final_shape):
        visual_bbox_x = torch.div(
            torch.arange(
                0,
                1000 * (image_feature_pool_shape[1] + 1),
                1000,
                device=device,
                dtype=bbox.dtype,
            ),
            self.config.image_feature_pool_shape[1],
            rounding_mode="floor",
        )
        visual_bbox_y = torch.div(
            torch.arange(
                0,
                1000 * (self.config.image_feature_pool_shape[0] + 1),
                1000,
                device=device,
                dtype=bbox.dtype,
            ),
            self.config.image_feature_pool_shape[0],
            rounding_mode="floor",
        )
        visual_bbox = torch.stack(
            [
                visual_bbox_x[:-1].repeat(image_feature_pool_shape[0], 1),
                visual_bbox_y[:-1].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
                visual_bbox_x[1:].repeat(image_feature_pool_shape[0], 1),
                visual_bbox_y[1:].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
            ],
            dim=-1,
        ).view(-1, bbox.size(-1))

        visual_bbox = visual_bbox.repeat(final_shape[0], 1, 1)
        return visual_bbox
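The same grid can be reproduced without PyTorch. This pure-Python sketch (the function name is illustrative) mirrors the arithmetic of `_calc_visual_bbox` for a single example: the 0-1000 page is cut into `pool_w` columns and `pool_h` rows using floor division, and boxes are emitted row by row as (x0, y0, x1, y1), the same row-major order in which the pooled feature map is flattened.

```python
from typing import List, Tuple

Box = Tuple[int, int, int, int]


def visual_grid_boxes(pool_h: int, pool_w: int) -> List[Box]:
    """One (x0, y0, x1, y1) box per visual token, in row-major order."""
    # Column and row edges on the normalized [0, 1000] page; floor division
    # matches torch.div(..., rounding_mode="floor") in _calc_visual_bbox.
    xs = [1000 * k // pool_w for k in range(pool_w + 1)]
    ys = [1000 * k // pool_h for k in range(pool_h + 1)]
    return [
        (xs[c], ys[r], xs[c + 1], ys[r + 1])
        for r in range(pool_h)
        for c in range(pool_w)
    ]


boxes = visual_grid_boxes(7, 7)
print(len(boxes))           # 49
print(boxes[0], boxes[-1])  # (0, 0, 142, 142) (857, 857, 1000, 1000)
```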



    final_bbox = torch.cat([bbox, visual_bbox], dim=1)

Here, `bbox` holds the token-level boxes passed to the model (an optional input), and `final_bbox` appends the visual grid boxes to them. If `bbox` is not given, all-zero boxes are used:

    if bbox is None:
        bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)

    text_layout_emb = self._calc_text_embeddings(
        input_ids=input_ids,
        bbox=bbox,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
    )

    visual_emb = self._calc_img_embeddings(
        image=image,
        bbox=visual_bbox,
        position_ids=visual_position_ids,
    )
    
    final_emb = torch.cat([text_layout_emb, visual_emb], dim=1)
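A list-based sketch of how the text and visual streams line up in the joint sequence. `build_joint_inputs` is a hypothetical helper, not a library function. It assumes, as in the Hugging Face model, that the WH visual tokens are appended after the L text tokens and are never masked out; it also assumes the visual 1D position ids restart at 0, consistent with PosEmb1D(i) for 0 ≤ i < WH in the visual-embedding formula.

```python
from typing import Dict, List, Sequence, Tuple

Box = Tuple[int, int, int, int]


def build_joint_inputs(
    text_boxes: Sequence[Box],
    text_mask: Sequence[int],
    visual_boxes: Sequence[Box],
) -> Dict[str, List]:
    """Concatenate text (length L) and visual (length W*H) streams, text first."""
    n_text, n_vis = len(text_boxes), len(visual_boxes)
    return {
        "bbox": list(text_boxes) + list(visual_boxes),
        # Visual tokens are always attended to.
        "attention_mask": list(text_mask) + [1] * n_vis,
        # Text positions 0..L-1, visual positions restart at 0 (assumption above).
        "position_ids": list(range(n_text)) + list(range(n_vis)),
    }


grid = [(0, 0, 500, 500), (500, 0, 1000, 500), (0, 500, 500, 1000), (500, 500, 1000, 1000)]
out = build_joint_inputs([(1, 1, 2, 2)] * 3, [1, 1, 0], grid)
print(out["position_ids"])  # [0, 1, 2, 0, 1, 2, 3]
```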

## 4.  Passing the prepared inputs to the encoder :- 

    encoder_outputs = self.encoder(
        final_emb,
        extended_attention_mask,
        bbox=final_bbox,
        position_ids=final_position_ids,
        head_mask=head_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = encoder_outputs[0]
    pooled_output = self.pooler(sequence_output)

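NOTE:-
**Inside the encoder, `final_bbox` and `final_position_ids` are used to compute the relative 2D and 1D position biases of the spatial-aware self-attention. They are computed once per forward pass and shared by all encoder layers.**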

#### 2D positional embeddings (Relative)

    def _calculate_2d_position_embeddings(self, bbox):
        # bbox: (batch, seq_len, 4) in (x0, y0, x1, y1) order.
        # The left edge x0 and the bottom edge y1 serve as each token's 2D position.
        position_coord_x = bbox[:, :, 0]
        position_coord_y = bbox[:, :, 3]
        # Pairwise differences: entry [b, i, j] = coord_j - coord_i
        rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
        rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
        # Map the signed distances to a fixed number of buckets
        rel_pos_x = relative_position_bucket(
            rel_pos_x_2d_mat,
            num_buckets=self.rel_2d_pos_bins,
            max_distance=self.max_rel_2d_pos,
        )
        rel_pos_y = relative_position_bucket(
            rel_pos_y_2d_mat,
            num_buckets=self.rel_2d_pos_bins,
            max_distance=self.max_rel_2d_pos,
        )
        # One learned bias per (bucket, head): (batch, seq, seq, heads) -> (batch, heads, seq, seq)
        rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
        rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
        rel_pos_x = rel_pos_x.contiguous()
        rel_pos_y = rel_pos_y.contiguous()
        # The x and y biases are summed into a single 2D spatial bias
        rel_2d_pos = rel_pos_x + rel_pos_y
        return rel_2d_pos

#### 1D positional embeddings (Relative)

    def _calculate_1d_position_embeddings(self, position_ids):
        rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
        rel_pos = relative_position_bucket(
            rel_pos_mat,
            num_buckets=self.rel_pos_bins,
            max_distance=self.max_rel_pos,
        )
        rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
        rel_pos = rel_pos.contiguous()
        return rel_pos
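The `relative_position_bucket` helper called in both methods is not shown above. It is the T5-style bucketing used in the Hugging Face LayoutLMv2 code: half of each direction's buckets hold exact small distances, larger distances share logarithmically wider buckets up to `max_distance`, and positive and negative offsets get separate halves. The scalar pure-Python version below is a re-implementation for illustration (not the library function), with defaults matching the 1D settings (`rel_pos_bins=32`, `max_rel_pos=128`); the 2D biases use `rel_2d_pos_bins=64` and `max_rel_2d_pos=256` by default. `bucket_matrix` is a hypothetical helper reproducing the pairwise `j - i` layout.

```python
import math
from typing import List, Sequence


def relative_position_bucket(
    relative_position: int,
    bidirectional: bool = True,
    num_buckets: int = 32,
    max_distance: int = 128,
) -> int:
    """Map a signed distance (j - i) to a bucket index in [0, num_buckets)."""
    ret = 0
    if bidirectional:
        num_buckets //= 2              # half the buckets per direction
        if relative_position > 0:
            ret += num_buckets         # positive offsets use the upper half
        n = abs(relative_position)
    else:
        n = max(-relative_position, 0)
    max_exact = num_buckets // 2
    if n < max_exact:                  # small distances: one bucket each
        return ret + n
    # Larger distances: logarithmic bins up to max_distance, then clamped.
    val = max_exact + int(
        math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    return ret + min(val, num_buckets - 1)


def bucket_matrix(positions: Sequence[int], **kwargs) -> List[List[int]]:
    """Entry [i][j] buckets positions[j] - positions[i], as in the methods above."""
    return [[relative_position_bucket(pj - pi, **kwargs) for pj in positions] for pi in positions]


print(bucket_matrix([0, 1, 2]))  # [[0, 17, 18], [1, 0, 17], [2, 1, 0]]
```

Bucketing keeps the bias tables small (32 or 64 rows per head) while still distinguishing nearby tokens precisely.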

### Attention modification 

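**Spatial-aware self-attention**

    # query_layer has already been scaled by 1 / sqrt(attention_head_size)
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    # Add the learned 1D relative-position bias
    if self.has_relative_attention_bias:
        attention_scores += rel_pos

    # Add the learned 2D spatial (x + y) bias
    if self.has_spatial_attention_bias:
        attention_scores += rel_2d_pos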