In [1]:
import torch
import numpy as np
import cv2
from utils.training_kits import stdout_to_tqdm, load_pretrained_state

from data import get_dataset
from config.config import DATASET, config_dict
from utils.visualization_tools import draw_heatmaps, draw_region_maps, draw_point, draw_bbox, draw_text,draw_centermap

from tensorboardX import SummaryWriter
import torchvision.utils as vutils

# Override two entries of the shared config before any dataset is built.
# NOTE(review): presumably get_dataset() reads 'batch_size' and 'workers' from
# config_dict when it creates the loader — confirm in data.get_dataset.
config_dict['batch_size'] = 1
config_dict['workers'] = 1

# Only the first entry of config_dict["image_size"] is read; draw_region_bbox
# uses this single value as the upper clipping bound for both x and y.
new_size = config_dict["image_size"][0]
def draw_region_bbox(img, xywhc, color=(0, 0, 255)):
    """Draw a centre-format box on ``img`` after clipping it to the image square.

    Args:
        img: image handed straight to ``draw_bbox``.
        xywhc: sequence whose first four entries are (cx, cy, w, h); anything
            after that (e.g. a confidence) is ignored.
        color: colour forwarded to ``draw_bbox``.

    Returns:
        Whatever ``draw_bbox`` returns for the clipped corner coordinates.
    """
    center_x, center_y, box_w, box_h = xywhc[:4]
    half_w, half_h = box_w / 2, box_h / 2

    # Left/top edges are floored at 0, right/bottom edges are capped at the
    # module-level ``new_size``; all four are truncated to int for drawing.
    left = int(max(0, center_x - half_w))
    top = int(max(0, center_y - half_h))
    right = int(min(center_x + half_w, new_size))
    bottom = int(min(center_y + half_h, new_size))

    return draw_bbox(img, left, top, right, bottom, color=color)


In [None]:
#  1. Visualise hand detection where the keypoints are predicted with SimDR
from models.center_simDR import LiteHourglassNet as Network
from utils.CenterSimDRParser import ResultParser

# Pre-declare the per-batch names at module level; every one starts out as None.
img = target_x = target_y = target_weight = None
centermap = centermask = bbox = gt_kpts = None
class TestPreds:
    """Visualise centre-map + SimDR hand keypoint predictions on the test split.

    Every processed batch is parsed into keypoints and boxes by ``ResultParser``
    and the rendered frames are written to TensorBoard under
    ``jupyter_log/<exp_name>``.

    NOTE(review): this notebook defines a second class called ``TestPreds`` in
    another cell; whichever cell was executed last owns the name.
    """

    def __init__(self, checkpoint="", is_cuda=True, ground_truth=False, exp_name=''):
        """Build the test loader, the network, the parser and the TensorBoard writer.

        Args:
            checkpoint: path of a checkpoint whose ``'state_dict'`` entry is
                loaded into the network; an empty string skips loading.
            is_cuda: place the network on CUDA device 0 when True, else on CPU.
            ground_truth: when True, ``test`` feeds the ground-truth targets to
                the parser instead of running the network.
            exp_name: sub-directory of ``jupyter_log/`` used for the event files.
        """
        print("preparing data...")
        self.dataset, self.test_loader = get_dataset(set_type='test')
        print("done!")

        if is_cuda:
            self.device = torch.device(0)
        else:
            self.device = torch.device('cpu')
        self.model = Network().to(self.device)

        self.ground_truth = ground_truth
        if checkpoint != "":
            print("loading state dict...")
            save_dict = torch.load(checkpoint, map_location=self.device)
            print(f"{save_dict.keys()=}")

            # The checkpoint weights are first reconciled with the model's own
            # state dict by the project helper, then loaded.
            state, is_match = load_pretrained_state(self.model.state_dict(),
                                                    save_dict['state_dict'])
            self.model.load_state_dict(state)
            print(f"done! {is_match=}")

        self.writer = SummaryWriter(log_dir='jupyter_log/'+ exp_name)
        self.parser = ResultParser()

    @staticmethod
    def _to_bgr_uint8(image):
        """Turn one normalised CHW tensor into an HWC uint8 array in BGR order.

        Applies ``x * s + m`` with the constants below (the usual ImageNet
        statistics — assumed to match the dataset transform), scales by 255 and
        swaps RGB to BGR for the OpenCV drawing helpers.
        """
        image = image.permute(1, 2, 0).detach().numpy()
        m = np.array([0.485, 0.456, 0.406])
        s = np.array([0.229, 0.224, 0.225])
        image = image * s + m
        image *= 255
        image = image.astype(np.uint8)
        return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    def test(self, n_img=10, show_hms=True, show_kpts=True):
        """Run up to ``n_img`` test batches and log the visualisations.

        Args:
            n_img: number of batches to process; ``-1`` walks the whole loader.
            show_hms: log the rendered centre map of every batch.
            show_kpts: log the input frame next to the frame with keypoints and
                the box drawn on it.
        """
        self.model.eval()
        with torch.no_grad():
            # NOTE(review): nothing in this variant ever adds to `pck`, so the
            # value printed at the end is always 0; the AP returned by the
            # parser below is computed but not reported.
            pck = 0
            n_done = 0  # batches actually processed; used as the averaging denominator
            if n_img == -1:
                n_img = len(self.test_loader)  # -1: go through the whole dataset
            for i, meta in enumerate(self.test_loader):
                # `>=` so that exactly n_img batches are handled (the previous
                # `>` let one extra batch through).
                if i >= n_img:
                    break
                n_done += 1
                start = cv2.getTickCount()  # timer start
                img, target_x, target_y, target_weight, centermap, centermask, bbox, gt_kpts = meta

                if not self.ground_truth:
                    pred_centermap, pred_x, pred_y = self.model(img.to(self.device))
                else:
                    pred_centermap, pred_x, pred_y = centermap, target_x, target_y

                # Parse the outputs into keypoints and bounding boxes.
                pred_kpts, pred_bboxes = self.parser.parse(pred_centermap, pred_x, pred_y)
                ap50, ap = self.parser.evaluate_ap(pred_bboxes, bbox)

                # Draw the centre map.
                if show_hms:
                    # NOTE(review): this renders the ground-truth `centermap`,
                    # not `pred_centermap` — confirm that is intended.
                    out_centermap = draw_centermap(centermap)
                    out_centermap = [torch.tensor(out, dtype=torch.uint8) for out in out_centermap]
                    imgs = torch.stack(out_centermap, dim=0)
                    self.writer.add_image('centermap', imgs, i, dataformats='NHWC')

                # Draw the keypoints.
                if show_kpts:
                    batch_xywh = pred_bboxes[0]
                    if batch_xywh is None:
                        print("没有找到目标框")
                        batch_xywh = [[4, 4, 2, 2, 0]]  # placeholder box so something is drawn

                    # Reset per batch so an empty zip can neither raise NameError
                    # nor re-log the frame of a previous batch.
                    image = image_drawn = None
                    for image, kpts, xywh in zip(img, pred_kpts, batch_xywh):
                        image = self._to_bgr_uint8(image)

                        kpts = kpts[0].detach().cpu().numpy()
                        print(f"{bbox=}")
                        print(f"{xywh=}")
                        print(f"{i=} ：{kpts=}")
                        print(F"{i=} ：{gt_kpts[0, 0]=}")
                        print('*'* 100)

                        image_drawn = draw_point(img=image.copy(), keypoints=kpts)
                        image_drawn = draw_region_bbox(image_drawn, xywh)

                    if image_drawn is None:
                        continue  # nothing was rendered for this batch

                    end = cv2.getTickCount()  # timer end
                    fps = round(cv2.getTickFrequency() / (end - start))
                    text = str(fps) + "fps"
                    # The return value was never used for logging; it is no longer
                    # bound to `img`, which is the batch tensor.
                    # NOTE(review): the fps text reaches the logged frame only if
                    # draw_text draws on `image_drawn` in place — confirm.
                    draw_text(image_drawn, text, (15, 15, 20, 20))
                    imgs = torch.stack([torch.tensor(image, dtype=torch.uint8),
                                        torch.tensor(image_drawn, dtype=torch.uint8)], dim=0)
                    self.writer.add_images('images', imgs, i, dataformats='NHWC')

            # Average over what was really processed (the loader may be shorter
            # than the requested n_img) and report that count.
            n_img = n_done
            pck = pck / n_done if n_done else 0.0
            print(f"{n_img=}")
            print(f"{pck=}")
            self.writer.close()
            

In [2]:
#  2. Visualise hand detection where the keypoints are predicted from heat maps
from models.hourglass_SA import HourglassNet_SA as Network
from utils.evaluation import evaluate_pck, evaluate_ap, get_coordinates_from_heatmap
from utils.CenterSimDRParser import ResultParser


class TestPreds:
    """Visualise heat-map based hand keypoint predictions on the test split.

    For every processed batch the running PCK is accumulated and, optionally,
    the region/keypoint heat maps are drawn and four frames (input, ground
    truth, first-pass keypoints, second-pass keypoints) are written to
    TensorBoard under ``jupyter_log/<exp_name>``.

    NOTE(review): this notebook defines another class called ``TestPreds`` in a
    different cell; whichever cell was executed last owns the name.
    """

    def __init__(self, checkpoint="", is_cuda=True, ground_truth=False, exp_name='cs'):
        """Build the test loader, the network, the parser and the TensorBoard writer.

        Args:
            checkpoint: path of a checkpoint whose ``'state_dict'`` entry is
                loaded into the network; an empty string skips loading.
            is_cuda: place the network on CUDA device 0 when True, else on CPU.
            ground_truth: when True, ``test`` uses the ground-truth heat maps
                instead of running the network.
            exp_name: sub-directory of ``jupyter_log/`` used for the event files.
        """
        print("preparing data...")
        self.dataset, self.test_loader = get_dataset(set_type='test')
        print("done!")

        if is_cuda:
            self.device = torch.device(0)
        else:
            self.device = torch.device('cpu')
        self.model = Network().to(self.device)

        self.ground_truth = ground_truth
        if checkpoint != "":
            print("loading state dict...")
            save_dict = torch.load(checkpoint, map_location=self.device)
            print(f"{save_dict.keys()=}")

            # The checkpoint weights are first reconciled with the model's own
            # state dict by the project helper, then loaded.
            state, is_match = load_pretrained_state(self.model.state_dict(),
                                                    save_dict['state_dict'])
            self.model.load_state_dict(state)
            print(f"done! {is_match=}")

        self.result_parser = ResultParser()
        self.writer = SummaryWriter(log_dir= 'jupyter_log/' + exp_name)

    @staticmethod
    def _to_bgr_uint8(image):
        """Turn one normalised CHW tensor into an HWC uint8 array in BGR order.

        Applies ``x * s + m`` with the constants below (the usual ImageNet
        statistics — assumed to match the dataset transform), scales by 255 and
        swaps RGB to BGR for the OpenCV drawing helpers.
        """
        image = image.permute(1, 2, 0).detach().numpy()
        m = np.array([0.485, 0.456, 0.406])
        s = np.array([0.229, 0.224, 0.225])
        image = image * s + m
        image *= 255
        image = image.astype(np.uint8)
        return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    def test(self, n_img=10, show_hms=True, show_kpts=True, show_cd=False):
        """Run up to ``n_img`` test batches, print the mean PCK and log frames.

        Args:
            n_img: number of batches to process; ``-1`` walks the whole loader.
            show_hms: draw the region maps and keypoint heat maps of each batch.
            show_kpts: log the input, ground-truth and predicted overlays.
            show_cd: replace each batch by the output of
                ``dataset.generate_cd_gt`` and use a box covering the whole
                image (the "cyclic detection" view).
        """
        self.model.eval()
        with torch.no_grad():
            pck = 0
            n_done = 0  # batches actually processed; used as the averaging denominator
            if n_img == -1:
                n_img = len(self.test_loader)  # -1: go through the whole dataset
            for i, (img, target_x, target_y, target_weight, kpts_hm, bbox, gt_kpts) in enumerate(self.test_loader):
                # `>=` so that exactly n_img batches are handled (the previous
                # `>` let one extra batch through).
                if i >= n_img:
                    break
                n_done += 1
                start = cv2.getTickCount()  # timer start

                if show_cd:  # show the effect of cyclic detection
                    img, target_x, target_y, kpts_hm, gt_kpts = self.dataset.generate_cd_gt(img, gt_kpts, bbox, target_weight)
                    # Box centred on the image and as large as the image itself.
                    bbox = torch.tensor([[[img.shape[3] / 2, img.shape[2] / 2,
                                     img.shape[3], img.shape[2]]]], device=bbox.device)

                # The first three channels are treated as region maps, the rest
                # as keypoint heat maps.
                target = kpts_hm
                region = target[:, :3]
                hm_kpts = target[:, 3:].to(self.device)

                if not self.ground_truth:
                    hms_list, pred_x, pred_y  = self.model(img.to(self.device))
                    hm_kpts = hms_list[-1][:, 3:]
                    region = hms_list[-1][:, :3]

                # NOTE(review): `target[:, 3:]` is not moved to self.device while
                # `hm_kpts` is — confirm evaluate_pck copes with that on CUDA.
                pck += evaluate_pck(hm_kpts, target[:, 3:], bbox, thr=0.2).item()
                print(f"{pck=}")

                # Draw the heat maps.
                if show_hms:
                    draw_region_maps(region)
                    draw_heatmaps(hm_kpts)

                # Draw the keypoints.
                if show_kpts:
                    _, _, batch_xywh = evaluate_ap(region, bbox, k=10, conf_thr=0.1)
                    batch_xywh = batch_xywh[0]
                    if batch_xywh is None:
                        print("没有找到目标框")
                        batch_xywh = [[4, 4, 2, 2, 0]]  # placeholder box so something is drawn

                    # Special-cased for the current single-hand setup: keep only
                    # the first image's boxes and the first hand's keypoints.
                    bbox = bbox[0]
                    gt_kpts = gt_kpts[:, 0]

                    first_time_kpt = self.result_parser.get_pred_kpt(hm_kpts)
                    # presumably rescales heat-map coordinates to image pixels
                    # (stride 4) — TODO confirm against the network's output size
                    first_time_kpt[:, :, :2] *= torch.tensor([4, 4], device=first_time_kpt.device)
                    pred_bboxes = self.result_parser.get_pred_bbox(region)
                    second_time_kpt = self.result_parser.get_group_keypoints(self.model, img, pred_bboxes, hm_kpts)

                    # Reset per batch so an empty zip can neither raise NameError
                    # nor re-log the frames of a previous batch.
                    image = image_gt = image_drawn1 = image_drawn2 = None
                    for image, gk, ftk, stk, xywh, gt_xywh in zip(img, gt_kpts,
                                                          first_time_kpt, second_time_kpt,
                                                          batch_xywh, bbox):
                        image = self._to_bgr_uint8(image)
                        image_gt = draw_results(image, gk, gt_xywh, gt_xywh)
                        image_drawn1 = draw_results(image, ftk, xywh, gt_xywh)
                        image_drawn2 = draw_results(image, stk, xywh, gt_xywh)

                    if image_drawn1 is None:
                        continue  # nothing was rendered for this batch

                    end = cv2.getTickCount()  # timer end
                    fps = round(cv2.getTickFrequency() / (end - start))
                    text = str(fps) + "fps"
                    # The return value was never used for logging; it is no longer
                    # bound to `img`, which is the batch tensor.
                    # NOTE(review): the fps text reaches the logged frame only if
                    # draw_text draws on `image_drawn1` in place — confirm.
                    draw_text(image_drawn1, text, (15, 15, 20, 20))
                    imgs = torch.stack([torch.tensor(image, dtype=torch.uint8),
                                        torch.tensor(image_gt, dtype=torch.uint8),
                                        torch.tensor(image_drawn1, dtype=torch.uint8),
                                        torch.tensor(image_drawn2, dtype=torch.uint8),
                                        ], dim=0)
                    self.writer.add_images('images', imgs, i, dataformats='NHWC')

            # Average over what was really processed: the old `n_img + 1`
            # denominator was wrong for n_img == -1 and for short loaders.
            n_img = n_done
            pck = pck / n_done if n_done else 0.0
            print(f"{n_img=}")
            print(f"{pck=}")
            self.writer.close()

def draw_results(image, kpts, xywh, gt_xywh):
    """Draw keypoints, a predicted box and a ground-truth box on a copy of ``image``.

    Args:
        image (numpy): (h, w, c) picture; it is copied, never modified here.
        kpts (tensor): (1, n_joints, 3) keypoints; the leading axis is squeezed.
        xywh (tensor): (5,) predicted box as (cx, cy, w, h, score).
        gt_xywh (tensor): (4,) ground-truth box as (cx, cy, w, h).

    Returns:
        The annotated copy returned by the drawing helpers.
    """
    keypoints_np = kpts.squeeze(dim=0).detach().cpu().numpy()
    canvas = draw_point(img=image.copy(), keypoints=keypoints_np)

    # Predicted box first with (255, 0, 0), then the ground truth with
    # (0, 255, 0) — blue and green if the image is in BGR order.
    for box, box_color in ((xywh, (255, 0, 0)), (gt_xywh, (0, 255, 0))):
        canvas = draw_region_bbox(canvas, box, box_color)
    return canvas
    

In [3]:
# Earlier checkpoint, kept for reference:
# path = "./checkpoint/MSRB-D-DW-PELEE/1HG-ME-att-c256/2021-12-27/0.981_PCK_47epoch.pt"
path = "checkpoint/final_ME-att/ls/1HG-ME-att-c128-h4-k2-o64-gtbbox-no_augment/2022-03-10/72.566_AP_52epoch.pt"
# Needs the heat-map TestPreds (the variant whose test() accepts show_cd).
# Runs on CPU with real network predictions; frames go to jupyter_log/gtbbox/r3.
t = TestPreds(checkpoint=path, is_cuda=False, ground_truth=False, exp_name='gtbbox/r3')
# n_img bounds the number of test batches; heat-map drawing is off, overlays are on.
t.test(n_img=100, show_hms=False, show_kpts=True, show_cd=False)
# To browse the logged frames:
# tensorboard --samples_per_plugin scalars=100,images=100 --logdir "./jupyter_log/gtbbox/"

preparing data...
sample number of testing dataset:  13024
done!
loading state dict...
save_dict.keys()=dict_keys(['epoch', 'lr', 'loss', 'mPCK', 'ap', 'state_dict', 'optimizer', 'config'])
done! is_match=True
pck=1.0
pck=1.9047619104385376
pck=2.9047619104385376
pck=3.9047619104385376
pck=4.904761910438538
pck=5.809523820877075
pck=6.809523820877075
pck=7.809523820877075
pck=7.809523820877075
pck=8.714285731315613
pck=9.61904764175415
pck=9.666666690260172
pck=10.476190511137247
pck=11.380952421575785
pck=12.238095287233591
pck=13.19047624245286
pck=14.19047624245286
pck=15.19047624245286
pck=16.19047624245286
pck=17.19047624245286
pck=17.19047624245286
pck=17.19047624245286
pck=17.19047624245286
pck=18.19047624245286
pck=19.19047624245286
pck=19.19047624245286
pck=20.19047624245286
pck=21.14285719767213
pck=22.047619108110666
pck=22.333333406597376
pck=22.333333406597376
pck=23.333333406597376
pck=24.238095317035913
pck=24.571428660303354
pck=24.666666757315397
pck=24.666666757315397

In [15]:
# Scratch array with shape (2, 4, 3), drawn uniformly from [0, 1), used below
# to check along which axis np.min reduces.
j = np.random.random_sample((2, 4, 3))
j

array([[[    0.65047,    0.019208,     0.41188],
        [    0.88008,     0.91811,     0.82871],
        [    0.36613,      0.6997,    0.056114],
        [    0.39086,     0.24489,     0.30793]],

       [[   0.046529,     0.69157,    0.043093],
        [    0.17521,     0.13193,     0.95549],
        [    0.70705,     0.86237,    0.061276],
        [     0.8503,     0.42167,     0.97655]]])

In [16]:
# Minimum over axis 1 (the length-4 axis): one row per leading index, shape (2, 3).
j.min(axis=1)

array([[    0.36613,    0.019208,    0.056114],
       [   0.046529,     0.13193,    0.043093]])