In [None]:
import torch, mmengine, mmcv, mmdet, mmdet3d, spconv
# Report the environment this notebook runs in: the PyTorch version, the CUDA
# version PyTorch reports, then the version of each OpenMMLab / spconv package.
print(torch.__version__)
print(torch.version.cuda)
for _pkg in (mmengine, mmcv, mmdet, mmdet3d, spconv):
    print(_pkg.__version__)

In [None]:
import argparse
import logging
import os
import os.path as osp
import torch
import mmcv
import numpy as np
import math

from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
from mmengine import fileio

from mmdet3d.utils import replace_ceph_backend
from projects.mmdet3d_plugin.models.detectors import CmtDetector
import time
from mmengine.structures import InstanceData
from mmdet.models.layers import inverse_sigmoid

In [None]:
# Load the experiment config and build the MMEngine runner from it.
# Alternative config that was tried here (kept for reference):
#   '../mmdetection3d/projects/BEVFusion/configs/my_bevfusion.py'
CONFIG_FILE = 'projects/configs/fusion/my_cmt_kitti.py'

cfg = Config.fromfile(CONFIG_FILE)
# The work directory is stored as an absolute path on the config before the
# runner is created from it.
cfg.work_dir = osp.abspath('./work_dirs')
runner = Runner.from_cfg(cfg)

In [None]:
# Put the model into training mode before the stage-by-stage walkthrough and
# the loss computation below. As the last expression of the cell, the value
# returned by train() is also displayed.
runner.model.train()

In [None]:
def pos2embed(pos, num_pos_feats=128, temperature=10000):
    """Encode 2-D positions as interleaved sine/cosine features.

    Args:
        pos: Tensor whose last axis holds at least two coordinates; index 0
            is treated as x and index 1 as y. Any leading axes are kept.
        num_pos_feats: Number of features produced per coordinate.
        temperature: Accepted for signature compatibility only; this
            implementation never reads it.

    Returns:
        Tensor with the same leading axes as ``pos`` and a last axis of size
        ``2 * num_pos_feats`` (for even ``num_pos_feats``): the y features
        followed by the x features.
    """

    def _interleave_sin_cos(angles):
        # Even channels go through sin, odd channels through cos; stacking on
        # a new last axis and flattening yields sin, cos, sin, cos, ...
        sin_part = angles[..., 0::2].sin()
        cos_part = angles[..., 1::2].cos()
        return torch.stack((sin_part, cos_part), dim=-1).flatten(-2)

    scaled = pos * (2 * math.pi)
    channel = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    # Channels 2k and 2k+1 share the divisor 1 + 2k / num_pos_feats.
    divisor = 2 * (channel // 2) / num_pos_feats + 1

    emb_x = _interleave_sin_cos(scaled[..., 0, None] / divisor)
    emb_y = _interleave_sin_cos(scaled[..., 1, None] / divisor)
    return torch.cat((emb_y, emb_x), dim=-1)

In [21]:
# Walk one training batch through the detector stage by stage. Every
# intermediate is kept in a notebook-level variable so the cells below can
# inspect it. Requires `pos2embed` from the cell above.
with torch.no_grad():
    detector = runner.model
    bbox_head = detector.pts_bbox_head

    # --- fetch and preprocess a single batch -------------------------------
    data_batch = next(iter(runner.train_dataloader))
    data_batch = detector.data_preprocessor(data_batch, training=False)
    batch_inputs_dict = data_batch['inputs']
    batch_data_samples = data_batch['data_samples']
    imgs = batch_inputs_dict.get('imgs', None)
    points = batch_inputs_dict.get('points', None)
    img_metas = [item.metainfo for item in batch_data_samples]
    gt_bboxes_3d = [item.get('gt_instances_3d')['bboxes_3d'] for item in batch_data_samples]
    gt_labels_3d = [item.get('gt_instances_3d')['labels_3d'] for item in batch_data_samples]

    # --- image branch --------------------------------------------------------
    img_feats = detector.extract_img_feat(imgs, img_metas)

    # --- point-cloud branch: voxelize -> encoders -> backbone -> neck ------
    voxels, num_points, coors = detector.voxelize(points)
    voxel_features = detector.pts_voxel_encoder(voxels, num_points, coors)
    # The batch size is derived from column 0 of the last voxel coordinate row.
    batch_size = coors[-1, 0] + 1
    x1 = detector.pts_middle_encoder(voxel_features, coors, batch_size)
    x2 = detector.pts_backbone(x1)
    # Use the backbone outputs directly when no neck is configured. Before,
    # `pts_feats` was only bound inside the `if`, so the next stage raised
    # NameError (or reused a stale value from an earlier run) without a neck.
    pts_feats = detector.pts_neck(x2) if detector.with_pts_neck else x2

    # --- detection head, unrolled ------------------------------------------
    ret_dicts = []  # not used in this cell; kept as in the original walkthrough
    x3 = bbox_head.shared_conv(pts_feats[0])
    reference_points = bbox_head.reference_points.weight
    reference_points, attn_mask, mask_dict = bbox_head.prepare_for_dn(x3.shape[0], reference_points, img_metas)

    # Positional embeddings for the image features and for the BEV grid.
    rv_pos_embeds = bbox_head._rv_pe(img_feats[0], img_metas)
    bev_pos_embeds = bbox_head.bev_embedding(
        pos2embed(bbox_head.coords_bev.to(x3.device), num_pos_feats=bbox_head.hidden_dim))

    # Query embeddings are the sum of the two embeddings returned by query_embed.
    bev_query_embeds, rv_query_embeds = bbox_head.query_embed(reference_points, img_metas)
    query_embeds = bev_query_embeds + rv_query_embeds

    outs_dec, _ = bbox_head.transformer(
        x3, img_feats[0], query_embeds,
        bev_pos_embeds, rv_pos_embeds,
        attn_masks=attn_mask
    )
    # Replace NaN/inf values in the decoder output before inspecting it.
    outs_dec = torch.nan_to_num(outs_dec)
    reference = inverse_sigmoid(reference_points.clone())

In [41]:
# Shapes of the shared-conv output (x3) and of the reference points returned
# by prepare_for_dn in the walkthrough cell.
x3.shape, reference_points.shape

(torch.Size([4, 256, 100, 100]), torch.Size([4, 1130, 3]))

In [42]:
# Shapes of the two positional embeddings and of the query embeddings
# (query_embeds is the element-wise sum of the BEV and RV query embeddings).
rv_pos_embeds.shape, bev_pos_embeds.shape, bev_query_embeds.shape, rv_query_embeds.shape, query_embeds.shape

(torch.Size([4, 20, 60, 256]),
 torch.Size([10000, 256]),
 torch.Size([4, 1130, 256]),
 torch.Size([4, 1130, 256]),
 torch.Size([4, 1130, 256]))

In [45]:
# Size of the leading axis of the transformer output and the shape of its
# first entry (presumably one entry per decoder layer — TODO confirm).
len(outs_dec), outs_dec[0].shape

(6, torch.Size([4, 1130, 256]))

In [None]:
# Recover a displayable image from the preprocessed batch: pick imgs[2][0],
# move its leading axis to the end (permute(1, 2, 0)), copy it to host memory
# as a NumPy array and undo the mean/std normalisation.
norm_mean = np.array([103.530, 116.280, 123.675])
norm_std = np.array([57.375, 57.120, 58.395])

img = imgs[2][0].permute(1, 2, 0).cpu().numpy()
# to_bgr=False: no channel reordering is requested from imdenormalize.
img = mmcv.imdenormalize(img, mean=norm_mean, std=norm_std, to_bgr=False)

In [None]:
import matplotlib.pyplot as plt

# Display the denormalised image. imshow expects float RGB data in [0, 1];
# denormalised values can land marginally outside that range, so clip
# explicitly instead of relying on matplotlib's clip-with-warning behaviour.
# NOTE(review): the image was denormalised with to_bgr=False; if the data
# pipeline keeps images in BGR order, the colours shown here are
# channel-swapped — confirm against the config.
plt.imshow(np.clip(img / 255, 0, 1));  # trailing ';' hides the AxesImage repr

In [None]:
# Ground-truth 3D boxes of the first sample in the batch, as collected from
# `gt_instances_3d` in the walkthrough cell.
gt_bboxes_3d[0]

In [None]:
# Shapes of the first two entries returned by extract_img_feat.
img_feats[0].shape, img_feats[1].shape

In [None]:
# Shapes along the point-cloud branch: voxel encoder output, middle encoder
# output, and the first two backbone outputs. Note that x3 is a single tensor
# (the shared-conv output), so x3[0] indexes its leading axis rather than
# selecting a feature level as x2[0] / x2[1] do.
voxel_features.shape, x1.shape, x2[0].shape, x2[1].shape, x3[0].shape

In [None]:
# Training pass: preprocess the first batch of the training dataloader and
# run the model once in 'loss' mode, leaving the result in `losses`.
# No backward pass or optimizer step is performed here.
for data_batch in runner.train_dataloader:
    data_batch = runner.model.data_preprocessor(data_batch, training=True)
    # Only dicts (unpacked as keyword arguments) and lists/tuples (unpacked
    # positionally) are accepted as preprocessor output.
    if not isinstance(data_batch, (dict, list, tuple)):
        raise TypeError()
    if isinstance(data_batch, dict):
        losses = runner.model(**data_batch, mode='loss')
    else:
        losses = runner.model(*data_batch, mode='loss')
    break  # one batch is enough for inspection

In [None]:
# Output of the model's 'loss' mode for the batch processed above.
losses

In [None]:
# Preprocessed inputs and the first data sample of whichever batch was last
# assigned to `data_batch` (the training batch when cells are run in order).
# NOTE(review): this displays the full input tensors; consider showing only
# shapes to keep the saved notebook small.
data_batch['inputs'], data_batch['data_samples'][0]

In [None]:
# Switch the model to evaluation mode before the validation and timing cells.
runner.model.eval()

In [None]:
# Validation pass: preprocess the first batch of the validation dataloader,
# run the model in 'predict' mode and hand the predictions to the evaluator.
# Inference runs under no_grad so no autograd graph is built.
with torch.no_grad():
    for data_batch in runner.val_dataloader:
        data_batch = runner.model.data_preprocessor(data_batch, training=False)
        if isinstance(data_batch, dict):
            outputs = runner.model(**data_batch, mode='predict')
        elif isinstance(data_batch, (list, tuple)):
            # A sequence must be unpacked positionally; `**` on a list/tuple
            # raises TypeError (matches the training cell's handling).
            outputs = runner.model(*data_batch, mode='predict')
        else:
            raise TypeError(
                'Output of `data_preprocessor` should be list, tuple or dict, '
                f'but got {type(data_batch)}')
        runner.val_evaluator.process(data_samples=outputs, data_batch=data_batch)
        break  # one batch is enough for inspection

# To turn the processed results into metrics, uncomment:
# metrics = runner.val_evaluator.evaluate(len(runner.val_dataloader.dataset))

In [None]:
# Raw values of a 100x100 window (indices 300..399 on the last two axes) of
# the first `imgs` entry in the current batch; all of the leading axis is kept.
data_batch['inputs']['imgs'][0][:,300:400,300:400]

In [None]:
# First element of the predictions returned by the model's 'predict' mode.
outputs[0]

In [None]:
# Display the validation dataloader object built by the runner.
runner.val_dataloader

In [None]:
# End-to-end timing: data loading + preprocessing + prediction for up to 100
# validation batches. CUDA kernels are launched asynchronously, so the device
# is synchronised before each clock reading; perf_counter is used because it
# is monotonic and meant for measuring intervals.
if torch.cuda.is_available():
    torch.cuda.synchronize()
time_start = time.perf_counter()
num = 0
with torch.no_grad():
    for data_batch in runner.val_dataloader:
        data_batch = runner.model.data_preprocessor(data_batch, training=False)
        if isinstance(data_batch, dict):
            outputs = runner.model(**data_batch, mode='predict')
        elif isinstance(data_batch, (list, tuple)):
            # A sequence must be unpacked positionally, not with `**`.
            outputs = runner.model(*data_batch, mode='predict')
        else:
            raise TypeError(
                'Output of `data_preprocessor` should be list, tuple or dict, '
                f'but got {type(data_batch)}')
        num += 1
        if num == 100:
            break
if torch.cuda.is_available():
    torch.cuda.synchronize()
# Elapsed seconds for the `num` batches processed.
print(time.perf_counter() - time_start)

In [None]:
# Timing with data loading mostly factored out: one validation batch is
# fetched once, then preprocessing + prediction are repeated 100 times on it.
# NOTE(review): the clock starts before the batch is fetched, so dataloader
# start-up and the first batch load are included in the total — confirm that
# is intended.
if torch.cuda.is_available():
    torch.cuda.synchronize()
time_start = time.perf_counter()
with torch.no_grad():
    data_batch_raw = next(iter(runner.val_dataloader))
    for _ in range(100):
        data_batch = runner.model.data_preprocessor(data_batch_raw, training=False)
        if isinstance(data_batch, dict):
            outputs = runner.model(**data_batch, mode='predict')
        elif isinstance(data_batch, (list, tuple)):
            # A sequence must be unpacked positionally, not with `**`.
            outputs = runner.model(*data_batch, mode='predict')
        else:
            raise TypeError(
                'Output of `data_preprocessor` should be list, tuple or dict, '
                f'but got {type(data_batch)}')
# CUDA work is asynchronous: wait for it to finish before reading the clock.
if torch.cuda.is_available():
    torch.cuda.synchronize()
print(time.perf_counter() - time_start)

In [None]:
def _synced_time():
    """Return a monotonic timestamp taken after pending CUDA work has finished.

    CUDA kernels are launched asynchronously; without this synchronisation a
    stage's GPU time would be attributed to whichever later stage happens to
    block first, making the per-stage split meaningless.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter()


# Per-stage timing on a single validation batch. Each stage is run 100 times:
# image feature extraction, point feature extraction, detection head, and
# box decoding.
with torch.no_grad():
    data_batch = next(iter(runner.val_dataloader))
    data_batch = runner.model.data_preprocessor(data_batch, training=False)
    batch_inputs_dict = data_batch['inputs']
    batch_data_samples = data_batch['data_samples']
    imgs = batch_inputs_dict.get('imgs', None)
    points = batch_inputs_dict.get('points', None)
    img_metas = [item.metainfo for item in batch_data_samples]

    time_start = _synced_time()
    for _ in range(100):
        img_feats = runner.model.extract_img_feat(imgs, img_metas)

    midedle1 = _synced_time()
    for _ in range(100):
        pts_feats = runner.model.extract_pts_feat(points, img_feats, img_metas)

    midedle2 = _synced_time()
    for _ in range(100):
        outs = runner.model.pts_bbox_head(pts_feats, img_feats, img_metas)

    middle3 = _synced_time()
    for _ in range(100):
        bbox_list = runner.model.pts_bbox_head.get_bboxes(
            outs, img_metas, rescale=False)

        # Packing the results into data samples is left out of the timing:
        # bbox_results = []
        # for bboxes, scores, labels in bbox_list:
        #     results = InstanceData()
        #     results.bboxes_3d = bboxes.to('cpu')
        #     results.scores_3d = scores.cpu()
        #     results.labels_3d = labels.cpu()
        #     bbox_results.append(results)
        # detsamples = runner.model.add_pred_to_datasample(batch_data_samples,
        #                                             data_instances_3d = bbox_results,
        #                                             data_instances_2d = None)
    end = _synced_time()
    # Seconds per stage (100 repetitions each): image features, point
    # features, head forward, box decoding.
    print(midedle1 - time_start, midedle2 - midedle1, middle3 - midedle2, end - middle3)