In [1]:

# coding: utf-8

import copy
import math
from functools import partial
from pprint import pprint

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

from sam.segment_anything.modeling import ImageEncoderViT
from sam.segment_anything.modeling.common import LayerNorm2d



In [7]:
class MaskDecoder(nn.Module):
    """Upsample a feature map 16x and predict a single-channel mask-logit map.

    Four ``ConvTranspose2d(kernel_size=2, stride=2)`` layers each double H and W.
    The first three halve the channel count; the last collapses it to 1.
    GELU activations sit between the layers. A ``LayerNorm2d`` follows only the first layer
    (the one after the second layer is intentionally disabled).

    Args:
        in_chans: channel count of the incoming feature map.
    """

    def __init__(self, in_chans):
        super().__init__()
        half = in_chans // 2
        quarter = half // 2
        eighth = quarter // 2
        act = nn.GELU
        layers = [
            nn.ConvTranspose2d(in_chans, half, kernel_size=2, stride=2),
            LayerNorm2d(half),
            act(),
            nn.ConvTranspose2d(half, quarter, kernel_size=2, stride=2),
            # LayerNorm2d(quarter),  # disabled on purpose
            act(),
            nn.ConvTranspose2d(quarter, eighth, kernel_size=2, stride=2),
            act(),
            nn.ConvTranspose2d(eighth, 1, kernel_size=2, stride=2),
        ]
        # Kept as a single Sequential named `decoder` so state_dict keys stay `decoder.<idx>.*`.
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Map [B, in_chans, h, w] features to [B, 1, 16h, 16w] logits (BCHW)."""
        logits = self.decoder(x)
        return logits
        

In [51]:

class TopoDecoder(nn.Module):
    """Transformer decoder that predicts candidate next nodes for each keypoint.

    Memory construction: for every keypoint, ``upsampled_features`` and ``image_embeddings``
    are concatenated along channels and the spatial grid is flattened.

    Decoding: ``NUM_QUERIES`` learned queries cross-attend to that memory, and each query is
    projected to 3 logits ``[p, dx, dy]``.

    Config keys read from ``cfg.TOPO_DECODER``:
      - in ``__init__``: HIDDEN_DIM, NUM_HEADS, DEPTH, DIM_FFN, ROI_SIZE, NUM_QUERIES
      - in ``forward``: NUM_POINTS
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.hidden_dim = cfg.TOPO_DECODER.HIDDEN_DIM
        self.num_heads = cfg.TOPO_DECODER.NUM_HEADS
        self.depth = cfg.TOPO_DECODER.DEPTH
        self.dim_ffn = cfg.TOPO_DECODER.DIM_FFN
        self.ROI_size = cfg.TOPO_DECODER.ROI_SIZE
        self.num_queries = cfg.TOPO_DECODER.NUM_QUERIES

        # One learned query per candidate output point.
        self.query_embed = nn.Embedding(self.num_queries, self.hidden_dim)

        # BUG: the relative position embedding is not clearly defined yet (or not needed for now).
        # self.rel_pos_embed = nn.Parameter(torch.randn(size=(self.ROI_size, self.ROI_size)), requires_grad=True)

        # XXX: still undecided whether an EncoderLayer, a DecoderLayer, or encoder-then-decoder fits best.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.hidden_dim,
            nhead=self.num_heads,
            dim_feedforward=self.dim_ffn,
            dropout=0.1,
            activation='relu',
            batch_first=True
        )
        self.topo_decoder = nn.TransformerDecoder(decoder_layer, num_layers=self.depth)

        # TODO: check how DETR outputs a fixed number of points (one point per query).
        self.output_proj = nn.Linear(self.hidden_dim, 3)  # per query: [p, dx, dy]

    @staticmethod
    def _clip_interval(lo, hi, size):
        """Clip the half-open interval [lo, hi) to [0, size).

        Returns ``(lo_c, hi_c, pad_before, pad_after)`` where ``lo_c:hi_c`` is the in-range
        slice. ``pad_before`` / ``pad_after`` are the lengths of the parts of [lo, hi) lying
        before 0 / at or after ``size``. The kept length plus both pads always equals ``hi - lo``,
        even when the interval lies completely outside the map.
        """
        lo_c = min(max(lo, 0), size)
        hi_c = max(min(hi, size), lo_c)
        pad_before = max(0, min(hi, 0) - lo)
        pad_after = max(0, hi - max(lo, size))
        return lo_c, hi_c, pad_before, pad_after

    def _get_valid_rc(self, r, c, H, W):
        """Clip a ROI_size x ROI_size window around keypoint (r, c) to an H x W map.

        The window is [c - ROI_size/2, c + ROI_size/2) horizontally and
        [r - ROI_size/2, r + ROI_size/2) vertically.

        Bounds are floored rather than truncated with ``int()``, which rounds negative values
        toward zero. This keeps the clipped slice plus padding exactly ROI_size wide for odd
        sizes and for points near or beyond the border.

        Returns:
            ``(left, right, up, down)``: slice bounds clipped to the map.
            ``(l_pad, r_pad, u_pad, d_pad)``: zero padding that restores the full ROI size,
            ordered as ``F.pad`` expects for the last two dims.
        """
        half = self.ROI_size / 2
        r, c = float(r), float(c)
        left, right, l_pad, r_pad = self._clip_interval(math.floor(c - half), math.floor(c + half), W)
        up, down, u_pad, d_pad = self._clip_interval(math.floor(r - half), math.floor(r + half), H)
        return (left, right, up, down), (l_pad, r_pad, u_pad, d_pad)

    def crop_ROI_feature(self, upsampled_features, keypoints):
        """Crop a zero-padded ROI_size x ROI_size window of ``upsampled_features`` around each keypoint.

        For an even ROI_size the window covers ROI_size/2 pixels above/left of the keypoint
        and ROI_size/2 - 1 below/right of it (e.g. 16 + 1 + 15 = 32 for ROI_SIZE=32).

        Args:
            upsampled_features: [B, C, H, W] feature map.
            keypoints: per-sample sequences of (r, c) keypoints, e.g. a [B, N_points, 2] tensor
                or a nested list. Every sample must have the same N_points so the crops can be stacked.
        Returns:
            [B, N_points, C, ROI_size, ROI_size] tensor.
        """
        B, C, H, W = upsampled_features.shape
        batch_crops = []
        for b in range(B):
            point_crops = []
            for r, c in keypoints[b]:
                (left, right, up, down), pad = self._get_valid_rc(r, c, H, W)
                crop = upsampled_features[b, :, up:down, left:right]
                point_crops.append(F.pad(crop, pad=pad, mode='constant', value=0))
            batch_crops.append(torch.stack(point_crops, dim=0))  # [N_points, C, ROI, ROI]
        return torch.stack(batch_crops, dim=0)  # [B, N_points, C, ROI, ROI]

    def forward(self, image_embeddings, upsampled_features, keypoints=None, keypoints_valid=None, with_asb_PE=False):
        """Decode NUM_QUERIES candidate points for every keypoint.

        Args:
            image_embeddings: [B, C_img, H, W] features.
            upsampled_features: [B, C_up, H, W] features. Must share H and W with
                ``image_embeddings`` because both are concatenated along channels.
                C_img + C_up must equal HIDDEN_DIM.
            keypoints: [B, N_points, 2]. Only ``shape[1]`` is used here and it must equal
                ``cfg.TOPO_DECODER.NUM_POINTS``.
            keypoints_valid: optional mask reshapeable to [B * N_points, H * W], passed as
                ``memory_key_padding_mask``. ``None`` disables masking.
                NOTE(review): PyTorch ignores memory positions whose boolean mask entry is True.
                If this tensor marks *valid* positions with True it must be inverted — TODO confirm.
            with_asb_PE: placeholder for an absolute positional encoding; currently unused.
        Returns:
            [B * N_points, NUM_QUERIES, 3] logits, one ``[p, dx, dy]`` triple per query.
        """
        B, C, H, W = image_embeddings.shape
        # Detach: changes made for a single point's ROI should not flow back into the shared inputs.
        image_embeddings = image_embeddings.detach()
        upsampled_features = upsampled_features.detach()
        if with_asb_PE:
            # TODO: an absolute positional encoding may be needed here.
            pass
        # BUG: samples in a batch may have different N_points (different sequence lengths);
        # for now callers must pad every sample to NUM_POINTS, otherwise a mask is needed.
        N_points = self.cfg.TOPO_DECODER.NUM_POINTS
        assert N_points == keypoints.shape[1], "Input num_points not equal to pre-defined NUM_POINTS for per patch!"

        # ROI cropping (crop_ROI_feature) is currently bypassed: every point sees the whole map.
        # expand() is a zero-copy broadcast; torch.cat below materialises the tensor once.
        ROI_features = upsampled_features.unsqueeze(1).expand(-1, N_points, -1, -1, -1)
        image_embeddings = image_embeddings.unsqueeze(1).expand(-1, N_points, -1, -1, -1)
        sum_C = C + ROI_features.shape[2]

        # [B, N_points, sum_C, H, W] -> [B*N_points, sum_C, H*W] -> [B*N_points, H*W, sum_C]
        memory = torch.cat([ROI_features, image_embeddings], dim=2).reshape(-1, sum_C, H * W).permute(0, 2, 1)

        # TODO: decide exactly which memory positions should be masked (compare with samRoad).
        memory_mask = None
        if keypoints_valid is not None:
            memory_mask = keypoints_valid.reshape(B * N_points, H * W)

        tgt = self.query_embed.weight.repeat(B * N_points, 1, 1)  # [B*N_points, NUM_QUERIES, hidden_dim]
        # TODO: a positional embedding may be needed here.
        x = self.topo_decoder(tgt=tgt, memory=memory, memory_key_padding_mask=memory_mask)
        return self.output_proj(x)  # hidden_dim -> 3 per query
            
            
# Smoke test: building a TopoDecoder from the project config should succeed.
from utils import load_config

cfg_path = './config/R2RC.yml'
cfg = load_config(cfg_path)
model = TopoDecoder(cfg=cfg)

# image_embeddings = torch.rand(size=(4, 256, 32, 32))
# upsampled_features = torch.rand(size=(4, 128, 32, 32))
# y = model(image_embeddings, upsampled_features)
# print(y.shape)

# # TODO 测试一个batch内各个cropped的feat的后半程是否一样 -> yes!
# upsampled_features = torch.rand(size=(4, 1, 512, 512))
# image_embeddings = torch.rand(size=(4, 1, 32, 32))
# upsampled_features[0] = torch.ones((1, 512, 512))
# upsampled_features[1] = torch.ones((1, 512, 512)) + 1

# image_embeddings[0] = torch.ones((1, 32, 32)) + 10
# image_embeddings[1] = torch.ones((1, 32, 32)) + 20

# keypoints = [[[100, 100], [200, 200]], 
#              [[100, 100], [200, 200]], 
#              [[100, 100], [200, 200]], 
#              [[100, 100], [200, 200]]]

# batch_cropped_features = model.crop_ROI_feature(upsampled_features, keypoints)
# print(batch_cropped_features.shape)
# B, C, H, W = image_embeddings.shape
# sum_C = C + batch_cropped_features.shape[2]   # 256 + 128 = 384
# b, n_points, c, h, w = batch_cropped_features.shape
# image_embeddings = image_embeddings.unsqueeze(1).repeat(1, n_points, 1, 1, 1)   # -> [B, 1, C, H, W]
# print(image_embeddings.shape)
# x = torch.concat([batch_cropped_features, image_embeddings], dim=2)
# # x = torch.concat([batch_cropped_features, image_embeddings], dim=2).reshape(-1, sum_C, H*W).permute(0, 2, 1)
# print(x.shape)

# print(x[1, 0, 0])
# print(x[1, 0, 1])
# print(x[1, 1, 0])
# print(x[1, 1, 1])



In [59]:
# Sanity-check _get_valid_rc + F.pad on a tiny 4x4 map.
# A 4x4 ROI centred at (1, 1) sticks out of the map, so it must be zero-padded on the left and top.
torch.set_printoptions(threshold=5000, edgeitems=3, linewidth=100)

# Work on a copy: mutating the shared `cfg` would silently change every other object holding it.
roi_cfg = copy.deepcopy(cfg)
roi_cfg.TOPO_DECODER.ROI_SIZE = 4
roi_decoder = TopoDecoder(roi_cfg)

r, c = 1, 1
(left, right, up, down), pad = roi_decoder._get_valid_rc(r, c, 4, 4)
feat = torch.rand((4, 4))
print(feat)
cropped_feat = feat[up:down, left:right]
print(cropped_feat.shape)
padded_feat = F.pad(cropped_feat, pad)
print(padded_feat)
print(padded_feat.shape)
 

tensor([[0.4230, 0.6379, 0.0859, 0.4980],
        [0.8984, 0.1095, 0.3243, 0.9580],
        [0.2919, 0.7649, 0.9141, 0.8269],
        [0.4134, 0.0099, 0.0915, 0.6846]])
torch.Size([3, 3])
tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4230, 0.6379, 0.0859],
        [0.0000, 0.8984, 0.1095, 0.3243],
        [0.0000, 0.2919, 0.7649, 0.9141]])
torch.Size([4, 4])


In [75]:
class FeatureUpsampler(nn.Module):
    """Resize encoder embeddings to the full patch resolution with bilinear interpolation.

    A learnable (transposed-conv) upsampler was considered but deferred because of the
    parameter increase, so this module has no parameters.

    Args:
        cfg: config object. ``cfg.PATCH_SIZE`` is the (square) output side length.
        in_chans: channel count of the incoming embeddings. Currently unused; kept for
            interface compatibility with a future learnable upsampler.
    """

    def __init__(self, cfg, in_chans=256):
        super().__init__()
        self.cfg = cfg

    def forward(self, image_embeddings):
        """Upsample ``image_embeddings`` [B, C, h, w] to [B, C, PATCH_SIZE, PATCH_SIZE]."""
        H = W = self.cfg.PATCH_SIZE
        # F.interpolate takes size=(out_H, out_W). The previous (W, W) only worked because H == W.
        return F.interpolate(image_embeddings, size=(H, W), mode='bilinear', align_corners=False)

In [8]:


class R2RC(pl.LightningModule):
    """LightningModule wiring a SAM ViT image encoder to a keypoint-mask decoder and a topology decoder.

    ``forward`` runs these steps:
      1. normalise the RGB patch;
      2. encode it with ``ImageEncoderViT``;
      3. predict keypoint-mask logits with ``MaskDecoder``;
      4. upsample the embeddings with ``FeatureUpsampler``;
      5. decode candidate next nodes with ``TopoDecoder``.

    Matching pretrained SAM weights are loaded from ``cfg.ENCODER.SAM_CKPT_PATH`` at construction.
    """

    def __init__(self, cfg):
        super().__init__()
        # overall cfg
        self.cfg = cfg
        self.image_size = cfg.PATCH_SIZE

        # data cfg: per-channel normalisation constants shaped [C, 1, 1] so they broadcast over
        # [B, C, H, W]; non-persistent, so they are not stored in checkpoints.
        self.register_buffer('mean', torch.Tensor(cfg.MEAN).view(-1, 1, 1), persistent=False)
        self.register_buffer('std', torch.Tensor(cfg.STD).view(-1, 1, 1), persistent=False)

        # model cfg
        self.encoder_output_dim = cfg.ENCODER.ENCODER_OUTPUT_DIM
        self.vit_patch_size = cfg.ENCODER.VIT_PATCH_SIZE

        # Only SAM ViT-B is wired up; add new backbones to this tuple AND to the branch below.
        assert cfg.ENCODER.BACKBONE in ('SAM-vit-b',), f"{cfg.ENCODER.BACKBONE} is not a valid backbone! "
        if cfg.ENCODER.BACKBONE == 'SAM-vit-b':
            self.encoder_embed_dim = 768
            self.encoder_num_transformer_blocks = 12
            self.encoder_num_heads = 12
            self.encoder_global_attn_indexes = [2, 5, 8, 11]

        self.image_encoder = self._init_image_encoder()
        self.mask_decoder = self._init_mask_decoder()
        self.feature_upsampler = self._init_feature_upsampler()
        self.topo_decoder = self._init_topo_decoder()

        # TODO criterion

        # TODO metrics

        self._load_pretrained_weights()

    def _init_image_encoder(self):
        """Build the SAM ViT image encoder for ``PATCH_SIZE`` inputs."""
        return ImageEncoderViT(
            img_size=self.image_size,
            patch_size=self.vit_patch_size,
            in_chans=3,
            embed_dim=self.encoder_embed_dim,
            depth=self.encoder_num_transformer_blocks,
            num_heads=self.encoder_num_heads,
            mlp_ratio=4.0,
            out_chans=self.encoder_output_dim,
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            use_abs_pos=True,
            use_rel_pos=True,
            window_size=14,
            global_attn_indexes=self.encoder_global_attn_indexes
        )

    def _init_mask_decoder(self):
        return MaskDecoder(self.encoder_output_dim)

    def _init_topo_decoder(self):
        return TopoDecoder(self.cfg)

    def _init_feature_upsampler(self):
        return FeatureUpsampler(self.cfg)

    def _resize_sam_pos_embedding(self, pretrained_state_dict):
        """Resize SAM's absolute and global-attention relative position embeddings to this token grid.

        The SAM checkpoint is built for a 1024/16 token grid. When
        ``PATCH_SIZE // VIT_PATCH_SIZE`` differs:
          - the absolute ``pos_embed`` is bilinearly resized;
          - the ``rel_pos_*`` tables of the global-attention blocks are resized to
            ``2 * token_size - 1`` rows.

        Windowed blocks use ``window_size=14`` regardless of the image size, so their tables
        are presumably size-independent and are left untouched.

        Returns a shallow copy of the state dict; the input dict is not modified.
        """
        new_state_dict = dict(pretrained_state_dict)
        pos_embed = new_state_dict['image_encoder.pos_embed']   # BHWC
        token_size = int(self.image_size // self.vit_patch_size)
        if pos_embed.shape[1] != token_size:    # != 1024/16
            # abs pos
            pos_embed = pos_embed.permute(0, 3, 1, 2)   # -> BCHW for interpolate
            pos_embed = F.interpolate(pos_embed, size=(token_size, token_size), mode='bilinear', align_corners=False)
            new_state_dict['image_encoder.pos_embed'] = pos_embed.permute(0, 2, 3, 1)
            # rel pos: match the exact block index (".blocks.5." must not also hit ".blocks.15.").
            global_rel_pos_keys = [
                k for k in new_state_dict
                if any(f'.blocks.{idx}.attn.rel_pos' in k for idx in self.encoder_global_attn_indexes)
            ]
            for k in global_rel_pos_keys:
                rel_pos_embed = new_state_dict[k]
                # XXX: the (sequence length, channels) table is treated as an HW image for
                # interpolation — is this reasonable, or should this part be retrained?
                L, C = rel_pos_embed.shape
                rel_pos_embed = rel_pos_embed[None, None]   # [1, 1, L, C]
                rel_pos_embed = F.interpolate(rel_pos_embed, size=(2 * token_size - 1, C), mode='bilinear', align_corners=False)
                new_state_dict[k] = rel_pos_embed[0, 0]
        return new_state_dict

    def _load_pretrained_weights(self):
        """Copy every parameter whose name and shape match the (resized) SAM checkpoint.

        Parameters without a match keep their fresh initialisation. The decoders and the
        upsampler typically have no counterpart in the checkpoint.
        """
        # map_location='cpu': load regardless of the device the checkpoint was saved on.
        # weights_only=True: the checkpoint is expected to be a plain tensor state_dict, so refuse
        # to unpickle arbitrary objects (this also silences torch's FutureWarning).
        state_dict = torch.load(self.cfg.ENCODER.SAM_CKPT_PATH, map_location='cpu', weights_only=True)
        state_dict = self._resize_sam_pos_embedding(state_dict)

        new_state_dict = {}
        matched_names = []
        unmatched_names = []
        for n, p in self.named_parameters():   # name, param
            if n in state_dict and p.shape == state_dict[n].shape:
                new_state_dict[n] = state_dict[n]
                matched_names.append(n)
            else:
                unmatched_names.append(n)

        if getattr(self.cfg, 'dev_run', False):
            print("========== Matched names ==========")
            pprint(matched_names)
            print()
            print("xxxxxxxxxx Unmatched names xxxxxxxxxx")
            pprint(unmatched_names)

        # strict=False because unmatched parameters are intentionally absent from new_state_dict.
        self.load_state_dict(new_state_dict, strict=False)

    def forward(self, rgb, keypoints, keypoints_valid=None):
        """Run the full model.

        Args:
            rgb: [B, H, W, C] channels-last image batch, normalised here with cfg MEAN/STD.
            keypoints: [B, N_points, 2] keypoints forwarded to the topo decoder.
            keypoints_valid: optional memory mask forwarded to ``TopoDecoder.forward``.
        Returns:
            ``(kpt_mask_logits, pred_next_nodes_logits)``:
              - ``kpt_mask_logits``: single-channel output of ``MaskDecoder``;
              - ``pred_next_nodes_logits``: ``[B * N_points, NUM_QUERIES, 3]`` rows of
                ``[p, dx, dy]`` from ``TopoDecoder``.
        """
        x = rgb.permute(0, 3, 1, 2)  # [B, C, H, W]
        x = (x - self.mean) / self.std

        image_embeddings = self.image_encoder(x)
        kpt_mask_logits = self.mask_decoder(image_embeddings)
        upsampled_features = self.feature_upsampler(image_embeddings)

        # TODO: the (cropped) upsampled features may need a relative position encoding.
        # TODO: predicted points are unordered, so the loss needs a matching step (e.g. Hungarian).
        # NOTE(review): TopoDecoder concatenates both inputs along channels, so they must share
        # H and W. Here image_embeddings is on the token grid while upsampled_features is
        # PATCH_SIZE x PATCH_SIZE — TODO reconcile.
        # Keyword arguments: TopoDecoder.forward takes (image_embeddings, upsampled_features, ...);
        # the previous positional call passed these two swapped.
        pred_next_nodes_logits = self.topo_decoder(
            image_embeddings=image_embeddings,
            upsampled_features=upsampled_features,
            keypoints=keypoints,
            keypoints_valid=keypoints_valid,
        )

        return kpt_mask_logits, pred_next_nodes_logits

    def training_step(self, batch, batch_idx):
        # TODO: compute the loss. Lightning skips the optimisation step when None is returned.
        loss = None
        return loss

    # Backward-compatible alias for the previous misspelled name (Lightning only calls `training_step`).
    training_setp = training_step

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        # TODO: build the optimizer and LR scheduler; both are still placeholders.
        optimizer = None
        lr_scheduler = None

        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

    # Backward-compatible alias for the previous misspelled name (Lightning only calls `configure_optimizers`).
    cfgure_optimizers = configure_optimizers
    
    
    

In [12]:


# Shape check: build the full model and inspect intermediate tensors on a random batch.
from utils import load_config

cfg_path = './config/R2RC.yml'
cfg = load_config(cfg_path)
cfg.dev_run = False

model = R2RC(cfg=cfg)
model.eval()

torch.manual_seed(0)
# Channels-first input fed straight to image_encoder, bypassing R2RC.forward (which expects [B, H, W, C]).
# The side length must equal PATCH_SIZE, the img_size the encoder's abs pos-embed was built for.
rgb = torch.rand(size=(4, 3, cfg.PATCH_SIZE, cfg.PATCH_SIZE))
with torch.no_grad():  # shapes only; no need to keep an autograd graph for a full ViT forward
    image_embeddings = model.image_encoder(rgb)
    upsampled_features = model.feature_upsampler(image_embeddings)
    mask = model.mask_decoder(image_embeddings)
print(image_embeddings.shape)
print(upsampled_features.shape)
print(mask.shape)



  state_dict = torch.load(f)


torch.Size([4, 256, 32, 32])
torch.Size([4, 256, 512, 512])
torch.Size([4, 1, 512, 512])
