Paper: (🚀🚀🚀 Accepted by CVPR2025 🚀🚀🚀):
-
MTMask6M:
- Coming soon
-
Original Document Data Sources:
Ensure you have Python 3.8 or higher installed in your environment.
git clone https://github.com/Token-family/Marten.git
cd Marten
pip install -r requirements.txt
Use OCR engines (PaddleOCR, CRAFT) to generate word-level bounding boxes in the format:
/path/to/image\t[[x1_1,y1_1,x2_1,y2_1,x3_1,y3_1,x4_1,y4_1], ... ,[x1_n,y1_n,x2_n,y2_n,x3_n,y3_n,x4_n,y4_n]]
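As a minimal sketch of the line format above, the helper below serializes one image path and its word boxes, and parses such a line back. It is not part of the Marten repository: the function names are hypothetical, the `\t` separator is assumed to be a literal tab character, and the example coordinates stand in for whatever your OCR engine returns.

```python
import json
from typing import List, Sequence, Tuple

# One word box: x1, y1, x2, y2, x3, y3, x4, y4 (four corner points).
Quad = Sequence[float]


def format_annotation_line(image_path: str, quads: Sequence[Quad]) -> str:
    """Serialize an image path and its word boxes as '<path><TAB>[[...], ...]'."""
    for quad in quads:
        if len(quad) != 8:
            raise ValueError(f"expected 8 coordinates per box, got {len(quad)}")
    # Compact separators keep the list free of spaces, as in the format above.
    boxes = json.dumps([list(q) for q in quads], separators=(",", ":"))
    return f"{image_path}\t{boxes}"


def parse_annotation_line(line: str) -> Tuple[str, List[List[float]]]:
    """Inverse of format_annotation_line: split on the first tab, decode boxes."""
    image_path, boxes = line.rstrip("\n").split("\t", 1)
    return image_path, json.loads(boxes)


# Hypothetical OCR output for two words (the corner order is an assumption).
example_quads = [
    [10, 20, 110, 20, 110, 50, 10, 50],
    [130, 20, 200, 20, 200, 50, 130, 50],
]
line = format_annotation_line("/path/to/image", example_quads)
print(line)
```

Writing one such line per image gives a plain-text file that can be checked with `parse_annotation_line` before running mask generation.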
python mask_utils/mask_generation.py
Refer to InternVL2 for the complete format specification:
{
  "id": 1,
  "image": "/path/to/image",
  "mask_path": "/path/to/mask",
  "conversations": [
    {
      "from": "human",
      "value": "<image>\nRecognize all text:"
    },
    {
      "from": "gpt",
      "value": "Fill in the visual text content here"
    }
  ]
}
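The record above can also be generated programmatically, which avoids hand-written syntax slips such as trailing commas that strict JSON parsers reject. The sketch below is illustrative only: `make_record` and `to_jsonl` are hypothetical helpers, and writing one JSON object per line (JSONL) is an assumption about the annotation file layout that you should check against the InternVL2 documentation.

```python
import json
from typing import Any, Dict, Iterable


def make_record(
    idx: int,
    image_path: str,
    mask_path: str,
    text: str,
    prompt: str = "<image>\nRecognize all text:",
) -> Dict[str, Any]:
    """Build one training sample with the fields shown above."""
    return {
        "id": idx,
        "image": image_path,
        "mask_path": mask_path,
        "conversations": [
            {"from": "human", "value": prompt},
            {"from": "gpt", "value": text},
        ],
    }


def to_jsonl(records: Iterable[Dict[str, Any]]) -> str:
    """Serialize records as one JSON object per line (assumed JSONL layout)."""
    return "".join(json.dumps(r, ensure_ascii=False) + "\n" for r in records)


record = make_record(
    1, "/path/to/image", "/path/to/mask", "Fill in the visual text content here"
)
jsonl_text = to_jsonl([record])
print(jsonl_text, end="")
```

Because `json.dumps` escapes the newline inside the prompt, each record stays on a single physical line.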
Follow InternVL2 methodology:
Pre-training
bash ./shell/marten_internlm2_intervit_pretrain.sh
Fine-tuning
bash ./shell/marten_internlm2_intervit_finetune.sh
Training with MGM
To integrate the MGM module into your own model structure, you can refer to the following code.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
...
from transformers.modeling_utils import PreTrainedModel
from ..marten_module.MGM import MGM
def dice_loss(pred, target, smooth=1e-6):
    """
    Compute the Dice loss for a binary segmentation problem.
    :param pred: predictions (logits; sigmoid is applied inside), shape [N, 1, H, W]
    :param target: ground-truth labels, shape [N, 1, H, W]
    :param smooth: smoothing term that prevents division by zero
    :return: the Dice loss value
    """
    pred = torch.sigmoid(pred)
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    intersection = (pred_flat * target_flat).sum()
    union = pred_flat.sum() + target_flat.sum()
    dice = (2. * intersection + smooth) / (union + smooth)
    return 1 - dice
class CustomModel(PreTrainedModel):
    def __init__(self, config, ..., use_mgm=False):
        ...
        llm_hidden_size = config.llm_config.hidden_size
        self.use_mgm = use_mgm
        if self.use_mgm:
            self.MGM_Decoder = MGM(llm_hidden_size, hidden_size=512, dev_convs_nums=4, out_channels=1, layer_num=4)
            self.MGM_Decoder._initialize_weights()
            self.MGM_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.5]))
            self.MGM_aug_loss = dice_loss
            self.select_llm_layer_idx = -4

    def forward(
        self,
        ...
        pixel_values_mask: Optional[torch.Tensor] = None,
        ...
    ):
        ...
        """
        image_llm_hidden_features: image-related features in the hidden layer of the LLM
        text_llm_hidden_features: text-related features in the hidden layer of the LLM
        output_size: equal to the image size
        image_patch_size: the patch size of the image
        image_token_num: the number of image tokens
        loss: the original LLM loss
        """
        if self.use_mgm and pixel_values_mask is not None:
            # The other parameters are customized for InternVL dynamic slicing. If you use other VFMs, you can delete them.
            MGM_output = self.MGM_Decoder(image_llm_hidden_features, text_llm_hidden_features, output_size, image_patch_size, image_token_num)
            # Mask supervision: weighted BCE plus the Dice loss registered as self.MGM_aug_loss.
            loss += self.MGM_loss(MGM_output, pixel_values_mask)
            loss += self.MGM_aug_loss(MGM_output, pixel_values_mask.float().long())
bash ./shell/eval.sh
- Release training / evaluation code for Marten series
- Release code for mask generation
- Release dataset of MTMask6M
Marten is built with reference to the code of the following project: InternVL2.

