
Marten: Visual Question Answering with Mask Generation for Multi-modal Document Understanding (CVPR 2025)

📖 Introduction

Paper (🚀🚀🚀 Accepted by CVPR 2025 🚀🚀🚀)

📄 MTMask6M Datasets

📚 Usage

📦 Installation

Ensure you have Python 3.8 or higher installed in your environment.

git clone https://github.com/Token-family/Marten.git
cd Marten
pip install -r requirements.txt

🛠️ Creating Your Own Dataset

Step 1: Obtain Word-level Bounding Boxes

Use OCR engines (e.g., PaddleOCR, CRAFT) to generate word-level bounding boxes in the following format:

/path/to/image\t[[x1_1,y1_1,x2_1,y2_1,x3_1,y3_1,x4_1,y4_1], ... ,[x1_n,y1_n,x2_n,y2_n,x3_n,y3_n,x4_n,y4_n]]
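For reference, here is a minimal sketch of how one such annotation line could be parsed; the example path and coordinates are made up for illustration:

import json

# A hypothetical annotation line: image path, a tab, then a list of 4-corner word polygons.
line = "/path/to/image.jpg\t[[10,10,50,10,50,30,10,30],[60,10,120,10,120,30,60,30]]"

image_path, boxes_str = line.rstrip("\n").split("\t")
boxes = json.loads(boxes_str)  # list of [x1,y1,x2,y2,x3,y3,x4,y4] quadrilaterals

print(image_path)  # /path/to/image.jpg
print(len(boxes))  # 2 word-level boxes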

Step 2: Generate Masks

python mask_utils/mask_generation.py
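Conceptually, one simple way to produce such masks is to rasterize the word-level quadrilaterals from Step 1 onto a blank canvas. Below is a minimal sketch using OpenCV; the helper `boxes_to_mask` is illustrative and not the actual API of `mask_generation.py`:

import numpy as np
import cv2

def boxes_to_mask(boxes, height, width):
    """Rasterize word-level quadrilaterals into a binary text mask."""
    mask = np.zeros((height, width), dtype=np.uint8)
    for box in boxes:
        polygon = np.array(box, dtype=np.int32).reshape(-1, 2)  # (4, 2) corner points
        cv2.fillPoly(mask, [polygon], 1)  # fill the text region with 1
    return mask

# Example: one word box on a 100x200 canvas.
mask = boxes_to_mask([[10, 10, 50, 10, 50, 30, 10, 30]], height=100, width=200)
print(mask.sum())  # number of text pixels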

Step 3: Data Format

Refer to InternVL2 for the complete format specification:

{
    "id": 1,
    "image": "/path/to/image",
    "mask_path": "/path/to/mask",
    "conversations": [
        {
            "from": "human",
            "value": "<image>\nRecognize all text:"
        },
        {
            "from": "gpt",
            "value": "Fill in the visual text content here"
        }
    ]
}
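As a rough illustration, a record in this format could be assembled and appended to an annotation file as follows; the file name `train_annotations.jsonl` is just an example, not something Marten requires:

import json

record = {
    "id": 1,
    "image": "/path/to/image",
    "mask_path": "/path/to/mask",
    "conversations": [
        {"from": "human", "value": "<image>\nRecognize all text:"},
        {"from": "gpt", "value": "Fill in the visual text content here"},
    ],
}

# One JSON object per line (JSONL), a layout commonly used for InternVL2-style annotation files.
with open("train_annotations.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")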

🚀 Training

Follow the InternVL2 training methodology:

Pre-training

bash ./shell/marten_internlm2_intervit_pretrain.sh

Fine-tuning

bash ./shell/marten_internlm2_intervit_finetune.sh

Training with MGM

If you want to integrate the MGM module into your own model structure, you can refer to the following code.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
...
from transformers.modeling_utils import PreTrainedModel
from ..marten_module.MGM import MGM

def dice_loss(pred, target, smooth=1e-6):
    """
    Compute the Dice loss for binary segmentation.
    :param pred: predicted logits, shape [N, 1, H, W]
    :param target: ground-truth labels, shape [N, 1, H, W]
    :param smooth: smoothing term to avoid division by zero
    :return: Dice loss value
    """
    pred = torch.sigmoid(pred)
    
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    
    intersection = (pred_flat * target_flat).sum()
    union = pred_flat.sum() + target_flat.sum()
    
    dice = (2. * intersection + smooth) / (union + smooth)
    
    dice_loss = 1 - dice
    
    return dice_loss


class CustomModel(PreTrainedModel):
    
    def __init__(self, config, ..., use_mgm=False):
        
        ...
        
        llm_hidden_size = config.llm_config.hidden_size
        self.use_mgm = use_mgm

        if self.use_mgm:
            self.MGM_Decoder = MGM(llm_hidden_size, hidden_size=512, dev_convs_nums=4, out_channels=1, layer_num=4)
            self.MGM_Decoder._initialize_weights()
            self.MGM_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.5]))    
            self.MGM_aug_loss = dice_loss 
            self.select_llm_layer_idx = -4

    def forward(
        self,
        ...
        pixel_values_mask: Optional[torch.Tensor] = None,
        ...
    ):

        ...

        """
        image_llm_hidden_features: image-related features from the LLM hidden layer
        text_llm_hidden_features: text-related features from the LLM hidden layer
        output_size: equal to the image size
        image_patch_size: the patch size of the image
        image_token_num: the number of image tokens
        loss: the original LLM loss

        """

        if self.use_mgm and pixel_values_mask is not None:

            # The remaining arguments are specific to InternVL dynamic slicing; if you use other VFMs, you can drop them.
            MGM_output = self.MGM_Decoder(image_llm_hidden_features, text_llm_hidden_features, output_size, image_patch_size, image_token_num)
            loss += self.MGM_loss(MGM_output, pixel_values_mask)
            loss += self.MGM_aug_loss(MGM_output, pixel_values_mask.float())
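
As a quick, purely illustrative sanity check of the two loss terms, using the `dice_loss` defined above (random tensors stand in for the decoder output and the ground-truth mask; the 448x448 resolution is only an example):

import torch
import torch.nn as nn

# Random stand-ins: mask-decoder logits and a binary ground-truth mask, both [N, 1, H, W].
MGM_output = torch.randn(2, 1, 448, 448)
pixel_values_mask = (torch.rand(2, 1, 448, 448) > 0.5).float()

bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.5]))
loss = bce(MGM_output, pixel_values_mask) + dice_loss(MGM_output, pixel_values_mask)
print(loss.item())  # scalar combining the BCE and Dice terms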

🔍 Evaluate

bash ./shell/eval.sh

📌 TODO List

  • Release training / evaluation code for Marten series
  • Release code for mask generation
  • Release dataset of MTMask6M

🙏 Acknowledgement

Marten is built with reference to the code of the following project: InternVL2

📜 Citation
