本方案参考F-VLM设计

需要微调

In [None]:
import cv2
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR

from transformers.models.clip.modeling_clip import CLIPVisionTransformer
from transformers import CLIPProcessor, CLIPModel

from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
from detectron2.modeling.proposal_generator.rpn import RPN
from detectron2.modeling.roi_heads.roi_heads import StandardROIHeads
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import Boxes, Instances
from detectron2.data.datasets import register_coco_instances

In [None]:
# COCO 2017 validation split: annotation JSON and image directory.
val_json = "../data/COCO/annotations/instances_val2017.json"
val_images = "../data/COCO/val2017"

# COCO 2017 training split: annotation JSON and image directory.
train_json = "../data/COCO/annotations/instances_train2017.json"
train_images = "../data/COCO/train2017"

# Annotation JSON in the working directory (presumably a reduced validation subset — TODO confirm).
val_small_json = "./val_small.json"

# Mask R-CNN checkpoint file and the config file name used together with it.
MRCNN_PATH = "../model/model_final_f10217.pkl"
CONFIG_FILE = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"

# Local snapshot directory of the CLIP ViT-B/32 (patch 32) checkpoint.
CLIP_PATH = "../model/clip-vit-patch32/models--openai--clip-vit-base-patch32/snapshots/3d74acf9a28c67741b2f4f2ea7635f0aaf6f0268"
# Checkpoint paths for this notebook's own Mask R-CNN variants
# (the second name suggests a classifier-free variant — TODO confirm).
MyMRCNN_PATH = "../model/my_mask_rcnn.pkl"
MyMRCNN_CLSFREE_PATH = "../model/my_clsfree_mask_rcnn.pkl"

实现余弦退火调度器

In [None]:
def cosine_scheduler_with_warmup(optimizer, warmup_steps, total_steps, min_lr_ratio=0.1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The returned multiplier (applied to the optimizer's base LR) rises linearly
    from 0 to 1 over ``warmup_steps``, then follows half a cosine period down to
    ``min_lr_ratio`` at ``total_steps``. Past ``total_steps`` it stays at
    ``min_lr_ratio``.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        warmup_steps: number of linear warmup steps.
        total_steps: step index at which the decay reaches its floor.
        min_lr_ratio: final multiplier, as a fraction of the base LR.

    Returns:
        LambdaLR: scheduler wrapping the multiplier function.
    """
    decay_steps = float(max(1, total_steps - warmup_steps))

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        # Clamp to 1.0: without it, stepping beyond total_steps moves further
        # along the cosine and the learning rate climbs back up.
        progress = min(1.0, (current_step - warmup_steps) / decay_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return min_lr_ratio + (1 - min_lr_ratio) * cosine_decay

    return LambdaLR(optimizer, lr_lambda)

这里有一部分元件需要从头训练，因此给出两种初始化代码备用

In [None]:
def init_weights_kaiming(m):
    """Kaiming-normal initialise a ``nn.Conv2d``; leave any other module untouched.

    Intended for use with ``module.apply(...)``. The convolution weight is drawn
    with ``mode='fan_out'`` and ``nonlinearity='relu'``; a bias, when present,
    is set to zero.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0)
            
def init_weights_xavier(m):
    """Xavier-uniform initialise a ``nn.Conv2d``; leave any other module untouched.

    Intended for use with ``module.apply(...)``. A bias, when present, is set
    to zero.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0)

训练时还需要用到对比损失，这里给出 InfoNCE 风格对比损失的实现（即带温度缩放的交叉熵）

In [None]:
def contrastive_loss(scores: torch.Tensor, gt_classes: torch.Tensor, temperature: float = 0.07):
    """Temperature-scaled cross-entropy used as the contrastive classification loss.

    Args:
        scores: similarity scores, one row per sample and one column per class.
        gt_classes: target class index for each row of ``scores``.
        temperature: divisor applied to ``scores`` before the softmax.

    Returns:
        Scalar loss tensor. For an empty batch this is a zero that is still
        connected to ``scores`` in the autograd graph.
    """
    if gt_classes.numel() == 0:
        # Mean-reduced cross-entropy over zero rows is NaN, which would poison
        # the summed training loss; return a differentiable zero instead.
        return scores.sum() * 0.0

    # Softmax temperature scaling followed by cross-entropy over the classes.
    return F.cross_entropy(scores / temperature, gt_classes)

按照 F-VLM 的设计，CLIP可以分为两部分，分别是特征提取与最后的Pooling层。

这里将其分出来

In [None]:
class CLIPVisionTransformerSplit(CLIPVisionTransformer):
    """CLIP vision transformer exposed as two separately callable stages."""

    def forward_features(self, pixel_values):
        """Feature-extractor stage: embeddings, pre-layernorm, then the encoder.

        Returns the first element of the encoder output (the last hidden state).
        """
        # `pre_layrnorm` is an attribute inherited from the parent class; the
        # spelling must be kept as is.
        normed_embeddings = self.pre_layrnorm(self.embeddings(pixel_values))
        return self.encoder(inputs_embeds=normed_embeddings)[0]

    def forward_pool(self, last_hidden_state):
        """Last-feature pooling stage: take token 0 and apply the post-layernorm.

        The result is not directly usable as an image embedding; it still has
        to go through the visual projection.
        """
        first_token = last_hidden_state[:, 0, :]
        return self.post_layernorm(first_token)

In [None]:
class TwoStageCLIPModel(CLIPModel):
    """
    CLIP model that exposes the vision encoder as two explicit stages:
    1. Feature extractor (patch embedding + transformer encoder)
    2. Last feature pooling layer (CLS pooling + LayerNorm), followed by the
       visual projection
    """

    def __init__(self, config):
        super().__init__(config)
        # Swap in the split vision tower so that weights loaded afterwards land
        # directly in the class that offers forward_features / forward_pool.
        self.vision_model = CLIPVisionTransformerSplit(config.vision_config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load pretrained weights and guarantee a split vision tower.

        ``__init__`` already installs ``CLIPVisionTransformerSplit``, so the
        loaded model normally needs no further change. The rebuild below is only
        a fallback; unlike an unconditional rebuild it keeps the replacement on
        the same device/dtype and in the same train/eval mode as the module it
        replaces (a freshly constructed module starts in training mode).
        """
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        loaded_tower = model.vision_model
        if isinstance(loaded_tower, CLIPVisionTransformerSplit):
            return model

        split_tower = CLIPVisionTransformerSplit(model.config.vision_config)
        split_tower.load_state_dict(loaded_tower.state_dict())
        reference = next(loaded_tower.parameters())
        split_tower.to(device=reference.device, dtype=reference.dtype)
        split_tower.train(loaded_tower.training)
        model.vision_model = split_tower

        return model

    def get_image_features_stage1(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """
        Return the token-level features produced by the feature extractor
        (``self.vision_model.forward_features``).

        Along the sequence dimension the first entry is the CLS token and the
        remaining entries are the per-patch embeddings.
        """
        return self.vision_model.forward_features(pixel_values)  # shape: (batch, seq_len, hidden_dim)

    def get_image_features_stage2(self, last_hidden_state: torch.FloatTensor) -> torch.FloatTensor:
        """
        Compute the pooled, projected image feature from the stage-1 output:
        ``vision_model.forward_pool`` (CLS token + post-layernorm) followed by
        ``self.visual_projection``.
        """
        pooled_state = self.vision_model.forward_pool(last_hidden_state)  # shape: (batch, hidden_dim)
        return self.visual_projection(pooled_state)

原版 RPN 的 forward 需要传入完整的 images 对象，但实际只用到其中的 image_sizes，而本方案中拿不到该对象；因此这里新建子类重写 forward，改为直接传入 image_sizes

In [None]:
class MyRPN(RPN):
    """RPN whose ``forward`` takes the image sizes directly instead of the images."""

    def forward(self, image_sizes, features, gt_instances=None):
        """
        Run the RPN on precomputed feature maps.

        The images were only ever needed for their sizes, so the sizes are
        passed in directly and no image container is required.

        Args:
            image_sizes: per-image sizes, forwarded to ``predict_proposals``.
            features: mapping from feature name to feature map; the entries
                named in ``self.in_features`` are used.
            gt_instances: ground-truth instances. Required in training mode;
                unused (and therefore optional) otherwise.

        Returns:
            (proposals, losses): ``losses`` is an empty dict outside training.
        """
        features = [features[f] for f in self.in_features]
        anchors = self.anchor_generator(features)

        pred_objectness_logits, pred_anchor_deltas = self.rpn_head(features)
        # Transpose the Hi*Wi*A dimension to the middle:
        pred_objectness_logits = [
            # (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A)
            score.permute(0, 2, 3, 1).flatten(1)
            for score in pred_objectness_logits
        ]
        pred_anchor_deltas = [
            # (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B) -> (N, Hi*Wi*A, B)
            x.view(x.shape[0], -1, self.anchor_generator.box_dim, x.shape[-2], x.shape[-1])
            .permute(0, 3, 4, 1, 2)
            .flatten(1, -2)
            for x in pred_anchor_deltas
        ]

        if self.training:
            assert gt_instances is not None, "RPN requires gt_instances in training!"
            gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)
            losses = self.losses(
                anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes
            )
        else:
            losses = {}
        proposals = self.predict_proposals(
            anchors, pred_objectness_logits, pred_anchor_deltas, image_sizes
        )
        return proposals, losses

Box Predictor需要重写 forward 与 loss

其中 loss 的分类项（loss_cls）应改为对比损失

In [None]:
class MyBoxPredictor(FastRCNNOutputLayers):
    """Box predictor that regresses boxes only; class scores are supplied by the caller."""

    def forward(self, x):
        """
        Args:
            x: per-region features of shape (N, ...) for N boxes.

        Returns:
            (None, Tensor): the first slot, where class scores would normally
            be, is always ``None`` because this predictor does not compute
            scores. The second is the output of ``self.bbox_pred`` on the
            flattened features (the box regression deltas).
        """
        flat = x if x.dim() <= 2 else torch.flatten(x, start_dim=1)
        return None, self.bbox_pred(flat)

    def losses(self, predictions, proposals):
        """
        Args:
            predictions: ``(scores, proposal_deltas)`` pair.
            proposals (list[Instances]): proposals matching the features used
                to compute the predictions. The fields ``proposal_boxes``,
                ``gt_boxes``, ``gt_classes`` are expected.

        Returns:
            Dict[str, Tensor]: ``loss_cls`` (contrastive loss on the scores) and
            ``loss_box_reg``, each multiplied by its entry in
            ``self.loss_weight`` (1.0 when absent).
        """
        scores, proposal_deltas = predictions

        # Classification targets.
        gt_classes = torch.cat([p.gt_classes for p in proposals], dim=0)

        # Box regression inputs.
        proposal_boxes = torch.cat([p.proposal_boxes.tensor for p in proposals], dim=0)  # Nx4
        assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"

        # Proposals without "gt_boxes" must be all negative and do not take
        # part in the regression loss; their own boxes serve as an arbitrary
        # placeholder whose values are not used by self.box_reg_loss().
        target_tensors = []
        for p in proposals:
            source = p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes
            target_tensors.append(source.tensor)
        gt_boxes = torch.cat(target_tensors, dim=0)

        cls_loss = contrastive_loss(scores, gt_classes)
        reg_loss = self.box_reg_loss(proposal_boxes, gt_boxes, proposal_deltas, gt_classes)

        weighted = {}
        for name, value in (("loss_cls", cls_loss), ("loss_box_reg", reg_loss)):
            weighted[name] = value * self.loss_weight.get(name, 1.0)
        return weighted

重写 ROI Head 的逻辑

In [None]:
class MyROIHeads(StandardROIHeads):
    """ROI heads that classify boxes by similarity to text embeddings."""

    def __init__(self, *, box_in_features, box_pooler, box_head, box_predictor: nn.Module, **kwargs):
        super().__init__(box_in_features=box_in_features, box_pooler=box_pooler, box_head=box_head,
                         box_predictor=box_predictor, **kwargs)
        # Projects box-head features into the text-embedding space.
        self.projection = nn.Linear(1024, 512)

    def forward(self, features, proposals, text_embeddings, targets=None):
        """
        Args:
            features: mapping from feature name to feature map.
            proposals: per-image proposals.
            text_embeddings: one embedding row per class; used to score boxes.
            targets: ground-truth instances, required in training mode.

        Returns:
            ``(proposals, losses)`` in training, ``(pred_instances, {})`` otherwise.
        """
        if self.training:
            assert targets, "'targets' argument is required during training"
            proposals = self.label_and_sample_proposals(proposals, targets)
        del targets

        if self.training:
            # text_embeddings must be forwarded here as well: _forward_box
            # needs it in both modes to compute the class scores.
            losses = self._forward_box(features, proposals, text_embeddings)
            # Usually the original proposals used by the box head are used by the mask, keypoint
            # heads. But when `self.train_on_pred_boxes is True`, proposals will contain boxes
            # predicted by the box head.
            losses.update(self._forward_mask(features, proposals))
            return proposals, losses
        else:
            pred_instances = self._forward_box(features, proposals, text_embeddings)
            # During inference cascaded prediction is used: the mask and keypoints heads are only
            # applied to the top scoring box detections.
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            return pred_instances, {}

    def _forward_box(self, features, proposals, text_embeddings):
        """Pool box features, regress deltas and score boxes against the text embeddings.

        Returns the loss dict in training mode and the predicted instances otherwise.
        """
        features = [features[f] for f in self.box_in_features]
        box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])
        box_features = self.box_head(box_features)
        _, box_deltas = self.box_predictor(box_features)
        box_features = self.projection(box_features)
        box_features = F.normalize(box_features, p=2, dim=1)  # final region embedding

        # Normalise the text side too so the scores are cosine similarities;
        # this is a no-op when the caller already passes unit-length rows.
        text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)
        scores = box_features @ text_embeddings.T

        if self.training:
            losses = self.box_predictor.losses((scores, box_deltas), proposals)
            # proposals is modified in-place below, so losses must be computed first.
            if self.train_on_pred_boxes:
                with torch.no_grad():
                    pred_boxes = self.box_predictor.predict_boxes_for_gt_classes(
                        (scores, box_deltas), proposals
                    )
                    for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes):
                        proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image)
            return losses
        else:
            pred_instances, _ = self.box_predictor.inference((scores, box_deltas), proposals)
            return pred_instances

对 Mask R-CNN 的实现

In [None]:
class MyMRCNN2(GeneralizedRCNN):
    """Mask R-CNN that consumes precomputed CLIP feature maps and text embeddings."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def forward(self, image_sizes, clip_feature, gt_instances, text_embeddings):
        """
        Compute the training losses.

        Args:
            image_sizes: per-image sizes handed to the proposal generator.
            clip_feature: input to ``self.backbone``.
            gt_instances: ground-truth instances for the batch.
            text_embeddings: class text embeddings handed to the ROI heads.

        Returns:
            dict of losses from the ROI heads and the proposal generator.
            Outside training mode the call is delegated to ``inference``.
        """
        if not self.training:
            return self.inference(clip_feature, gt_instances)

        features = self.backbone(clip_feature)
        # The proposal generator also yields the losses of the initial proposals.
        proposals, proposal_losses = self.proposal_generator(image_sizes, features, gt_instances)
        _, detector_losses = self.roi_heads(features, proposals, text_embeddings, gt_instances)

        # Periodic training visualisation (vis_period) is deliberately not done here.

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses

    def inference(self, clip_feature, gt_instances):
        raise NotImplementedError()

    @classmethod
    def from_pretrained(cls, cfg_path, weight_path, device="cuda"):
        """Build the model from a config, load weights, then swap in the custom parts.

        Args:
            cfg_path: config file merged into the default config.
            weight_path: checkpoint loaded with ``DetectionCheckpointer``.
            device: value written to ``cfg.MODEL.DEVICE``.

        Returns:
            The model with a CLIP adapter as FPN bottom-up and the RPN, ROI
            heads and box predictor re-typed to the subclasses in this file.
        """
        cfg = get_cfg()
        cfg.merge_from_file(cfg_path)
        cfg.MODEL.WEIGHTS = weight_path
        cfg.MODEL.DEVICE = device

        model = cls(cfg)
        checkpointer = DetectionCheckpointer(model)
        checkpointer.load(weight_path)

        # Replace the FPN bottom-up with the CLIP adapter. The FPN lateral convs
        # were built for the channel widths of the bottom-up being replaced, so
        # the adapter must emit those same widths per level. The adapter's
        # out_channels_list is ordered res5, res4, res3, res2.
        replaced_shapes = model.backbone.bottom_up.output_shape()
        adapter_channels = [
            replaced_shapes[level].channels for level in ("res5", "res4", "res3", "res2")
        ]
        model.backbone.bottom_up = CLIPtoFPNAdapter(out_channels_list=adapter_channels)

        # Re-type the components to the more convenient subclasses.
        model.proposal_generator.__class__ = MyRPN

        model.roi_heads.__class__ = MyROIHeads
        # MyROIHeads.__init__ does not run on a __class__ swap, so the projection
        # that aligns box features with the text embeddings is added by hand.
        model.roi_heads.projection = nn.Linear(1024, 512)

        model.roi_heads.box_predictor.__class__ = MyBoxPredictor
        model.roi_heads.box_predictor.cls_score = None  # drop the unused classifier weights

        return model

CLIP stage 1 的输出无法直接交给FPN，因此需要一个Adapter

FPN 的实现中恰好有起类似作用的组件，即 FPN 的 bottom_up，可以把它替换成这个 Adapter

使用卷积(kernel_size=1时等同于全连接)+插值来构造多层特征

In [None]:
class CLIPtoFPNAdapter(nn.Module):
    """Turn one feature map into four pyramid levels named ``res2``..``res5``.

    Every level applies its own 1x1 convolution to the input feature map,
    bilinearly upsampled by 1x (res5), 2x (res4), 4x (res3) and 8x (res2).

    Args:
        in_channels: channel count of the input feature map.
        out_channels_list: four output channel counts, ordered
            ``res5, res4, res3, res2``.

    Raises:
        ValueError: if ``out_channels_list`` does not have exactly four entries.
    """

    def __init__(self, in_channels=768, out_channels_list=(256, 256, 256, 256)):
        super().__init__()
        out_channels_list = list(out_channels_list)
        if len(out_channels_list) != 4:
            raise ValueError(
                f"out_channels_list must have 4 entries (res5..res2), got {len(out_channels_list)}"
            )

        # Every conv reads the (upsampled) input directly, so they all share the
        # same in_channels; the convs are parallel branches, not a chain.
        self.convs = nn.ModuleList(
            [nn.Conv2d(in_channels, out_ch, kernel_size=1) for out_ch in out_channels_list]
        )

        # Xavier-uniform weights and zero bias for the newly added convolutions.
        for conv in self.convs:
            nn.init.xavier_uniform_(conv.weight)
            nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        """Return ``{"res2", "res3", "res4", "res5"}`` feature maps built from ``x``.

        For the default use, x is [B, 768, 7, 7] and the outputs are 56x56,
        28x28, 14x14 and 7x7 respectively.
        """
        def upsampled(factor):
            return F.interpolate(x, scale_factor=factor, mode='bilinear', align_corners=False)

        c5 = self.convs[0](x)
        c4 = self.convs[1](upsampled(2))
        c3 = self.convs[2](upsampled(4))
        c2 = self.convs[3](upsampled(8))

        return {
            "res2": c2,
            "res3": c3,
            "res4": c4,
            "res5": c5,
        }

Dataset和Dataloader总体沿用之前的

In [None]:
class MyDataset(Dataset):
    """Dataset over detectron2-style dataset dicts that loads the image and builds ``Instances``.

    Args:
        dataset_name: name to look up in ``DatasetCatalog`` / ``MetadataCatalog``.
            When given, the other two arguments are ignored.
        dataset_dicts: list of dataset dicts, used when ``dataset_name`` is None.
        meta_data: metadata object, used when ``dataset_name`` is None.

    Raises:
        ValueError: if ``dataset_name`` is None and either of the other two is None.
    """

    def __init__(self, dataset_name=None, dataset_dicts=None, meta_data=None):
        if dataset_name is not None:
            self.dataset_dicts = DatasetCatalog.get(dataset_name)
            self.meta_data = MetadataCatalog.get(dataset_name)
            return

        if dataset_dicts is None or meta_data is None:
            raise ValueError("dataset_name为None时dataset_dicts与meta_data不能为None")
        self.dataset_dicts = dataset_dicts
        self.meta_data = meta_data

    def __len__(self):
        return len(self.dataset_dicts)

    def __getitem__(self, idx):
        """Return a copy of the dataset dict with ``image`` and ``instances`` added.

        ``image`` is a float32 CHW tensor in the channel order produced by
        ``cv2.imread`` (BGR). ``instances`` carries ``gt_boxes`` (x0, y0, x1, y1)
        and ``gt_classes``.

        Raises:
            OSError: if the image file cannot be read.
        """
        d = self.dataset_dicts[idx].copy()

        org_img = cv2.imread(d["file_name"])
        if org_img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"cv2.imread could not read image file: {d['file_name']}")

        # HWC uint8 -> CHW float32 tensor.
        d['image'] = torch.as_tensor(org_img.astype("float32").transpose(2, 0, 1))

        # The model expects ground-truth 'Instances' during training.
        height, width = org_img.shape[:2]
        instances = Instances((height, width))

        annotations = d["annotations"]
        # bbox is (x, y, w, h); convert to corner format.
        boxes = [[x, y, x + w, y + h] for x, y, w, h in (ann["bbox"] for ann in annotations)]
        classes = [ann["category_id"] for ann in annotations]

        # reshape keeps an image without annotations as an explicit (0, 4) tensor.
        instances.gt_boxes = Boxes(torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4))
        instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
        d["instances"] = instances

        return d

def build_batch_loader(dataset_name=None, dataset_dicts=None, meta_data=None, batch_size=1, shuffle=False):
    """Wrap a ``MyDataset`` in a ``DataLoader`` that yields each batch as a plain list.

    The dataset is built from ``dataset_name`` when it is given, otherwise from
    ``dataset_dicts`` and ``meta_data``.

    Raises:
        ValueError: if ``dataset_name`` is None and either ``dataset_dicts`` or
            ``meta_data`` is None.
    """
    if dataset_name is not None:
        dataset = MyDataset(dataset_name=dataset_name)
    elif dataset_dicts is not None and meta_data is not None:
        dataset = MyDataset(dataset_dicts=dataset_dicts, meta_data=meta_data)
    else:
        raise ValueError("dataset_name为None时dataset_dicts与meta_data不能为None")

    # Identity collate: samples are dicts holding tensors of different sizes,
    # so the batch is handed over as the list of samples.
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "collate_fn": lambda batch: batch,
    }
    return DataLoader(dataset, **loader_options)

In [None]:
class MyFVLM(nn.Module):
    """Frozen CLIP feature extractor combined with a trainable Mask R-CNN detector.

    Args:
        clip_path: path passed to the CLIP model/processor ``from_pretrained``.
        mask_rcnn_cfg_path: detector config path; when None no detector is built
            and ``self.mask_rcnn_model`` is None.
        mask_rcnn_weight_path: detector checkpoint path.
        device: device the models are moved to.
    """

    def __init__(self, clip_path, mask_rcnn_cfg_path, mask_rcnn_weight_path, device="cpu"):
        super().__init__()
        self.device = device

        self.clip_model = TwoStageCLIPModel.from_pretrained(clip_path).to(device)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_path)
        if mask_rcnn_cfg_path is not None:
            self.mask_rcnn_model = MyMRCNN2.from_pretrained(
                mask_rcnn_cfg_path, mask_rcnn_weight_path, device=str(device)
            ).to(device)
        else:
            # Define the attribute in every case so later accesses do not fail
            # with a missing-attribute error.
            self.mask_rcnn_model = None

        self.froze_VLM()

    def froze_VLM(self):
        """Disable gradients for every CLIP parameter."""
        for param in self.clip_model.parameters():
            param.requires_grad = False

    @staticmethod
    def _tokens_to_feature_map(hidden_states):
        """Reshape (B, 1 + H*W, C) encoder tokens into a (B, C, H, W) feature map.

        The leading CLS token is dropped; the remaining tokens are laid out
        row-major on a square grid.

        Raises:
            ValueError: if the number of patch tokens is not a perfect square.
        """
        patch_tokens = hidden_states[:, 1:, :]
        batch, num_patches, channels = patch_tokens.shape
        side = int(round(math.sqrt(num_patches)))
        if side * side != num_patches:
            raise ValueError(f"cannot arrange {num_patches} patch tokens on a square grid")
        return patch_tokens.transpose(1, 2).reshape(batch, channels, side, side)

    def forward(self, batched_inputs, text_embeddings):
        """Compute the detector losses for one batch (training mode only).

        Args:
            batched_inputs: list of sample dicts with ``image`` (CHW tensor) and
                ``instances``.
            text_embeddings: class text embeddings, e.g. from ``get_cls_embedding``.

        Returns:
            dict of losses from the detector.
        """
        if not self.training:
            raise NotImplementedError()

        image_sizes = [(bat["image"].shape[1], bat["image"].shape[2]) for bat in batched_inputs]
        # Images loaded with cv2.imread are BGR; flip the channel axis to RGB
        # before CLIP preprocessing.
        # NOTE(review): assumes every "image" is BGR — confirm if another loader is used.
        batch_imgs = [bat["image"].flip(0) for bat in batched_inputs]

        # TODO: could this preprocessing be done before training?
        clip_inputs = self.clip_processor(images=batch_imgs, return_tensors='pt', padding=True).to(self.device)

        # CLIP is frozen, so no graph is needed for its forward pass. The
        # detector backbone takes a (B, C, H, W) map, not the processor output.
        with torch.no_grad():
            hidden_states = self.clip_model.get_image_features_stage1(clip_inputs["pixel_values"])
        clip_feature = self._tokens_to_feature_map(hidden_states)

        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]

        # NOTE(review): image_sizes and the gt boxes are in original-image pixels,
        # while the CLIP processor presumably resizes/crops its input — confirm
        # that the two coordinate frames agree.
        loss_dict = self.mask_rcnn_model(image_sizes, clip_feature, gt_instances, text_embeddings)

        return loss_dict

    def inference(self, batch):
        raise NotImplementedError()

    def class_name_list_prepare(self, class_name_list):
        """Return a new list: the class names followed by a background entry.

        The input list is left untouched, so dataset metadata passed in is not
        modified and repeated calls do not accumulate background entries.
        """
        # TODO: prompt templates such as "a photo of ..." for zero-shot testing?
        # TODO: confirm how the background class should be represented.
        return [*class_name_list, "background or no object"]

    def get_cls_embedding(self, class_name_list):
        """Encode the class names with the CLIP text encoder (``get_text_features``)."""
        class_inputs = self.clip_processor(text=class_name_list, return_tensors="pt", padding=True).to(
            self.clip_model.device)
        return self.clip_model.get_text_features(**class_inputs)

    def save(self, path):
        """Save the detector's state dict under the key ``"model"``."""
        torch.save({"model": self.mask_rcnn_model.state_dict()}, path)

    def load(self, path, device):
        """Load the detector's state dict from ``path`` and record ``device``."""
        self.device = device
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=device)
        self.mask_rcnn_model.load_state_dict(checkpoint["model"])

定义训练保存机制

In [None]:
def save_training_state(model, optimizer, scheduler, epoch, path):
    """Write a resumable training checkpoint to ``path``.

    The checkpoint is a dict with the keys ``epoch``, ``model`` (state dict of
    ``model.mask_rcnn_model``), ``optimizer`` and ``scheduler`` (``None`` when
    no scheduler is given).
    """
    state = {
        "epoch": epoch,
        "model": model.mask_rcnn_model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    if scheduler:
        state["scheduler"] = scheduler.state_dict()
    else:
        state["scheduler"] = None
    torch.save(state, path)


def load_training_state(model, optimizer, scheduler, path, device="cuda"):
    """Restore training state from a checkpoint and return the stored epoch.

    Each of the ``model`` / ``optimizer`` / ``scheduler`` entries is loaded only
    when the checkpoint holds a truthy value for it and the matching target
    object is truthy. Returns 0 when the checkpoint has no ``epoch`` entry.
    """
    state = torch.load(path, map_location=device)

    model_state = state.get("model")
    if model_state and model.mask_rcnn_model:
        model.mask_rcnn_model.load_state_dict(model_state)

    optimizer_state = state.get("optimizer")
    if optimizer_state and optimizer:
        optimizer.load_state_dict(optimizer_state)

    scheduler_state = state.get("scheduler")
    if scheduler_state and scheduler:
        scheduler.load_state_dict(scheduler_state)

    return state.get("epoch", 0)

In [None]:
def train(model, optimizer, scheduler, epoch_num, batch_size, dataset_name, dataset_path, dataset_json, shuffle=False):
    """Register a COCO-format dataset and run the model over it for ``epoch_num`` epochs.

    Args:
        model: object providing ``class_name_list_prepare``, ``get_cls_embedding``
            and a call interface ``model(batch, class_embeddings)``.
        optimizer: not used in the lines visible here.
        scheduler: not used in the lines visible here.
        epoch_num: number of passes over the loader.
        batch_size: batch size handed to ``build_batch_loader``.
        dataset_name: name under which the dataset is registered.
        dataset_path: image root passed to ``register_coco_instances``.
        dataset_json: annotation file passed to ``register_coco_instances``.
        shuffle: whether the loader shuffles.

    NOTE(review): the visible loop body only computes ``losses``; no backward
    pass, optimizer step or scheduler step appears in these lines — confirm
    whether the function is unfinished or continues elsewhere.
    """
    # NOTE(review): registration happens on every call — presumably a second call
    # with the same dataset_name fails; verify against detectron2's catalog behaviour.
    register_coco_instances(dataset_name, {}, dataset_json, dataset_path)
    dataset_dicts = DatasetCatalog.get(dataset_name)
    meta_data = MetadataCatalog.get(dataset_name)
    
    loader = build_batch_loader(dataset_dicts=dataset_dicts, meta_data=meta_data,
                                batch_size=batch_size, shuffle=shuffle)
    
    # The class text embeddings are computed once, without gradients, and
    # reused for every batch.
    with torch.no_grad():
        class_name_list = model.class_name_list_prepare(meta_data.thing_classes)
        class_embeddings = model.get_cls_embedding(class_name_list)
    
    for epoch in range(epoch_num):
        for batch in loader:
            losses = model(batch, class_embeddings)
            
            