In [1]:
import torch
import numpy as np
import cv2
from utils.training_kits import stdout_to_tqdm, load_pretrained_state

from data import get_dataset
from config import get_config
from utils.visualization_tools import draw_heatmaps, draw_region_maps, draw_point, draw_bbox, draw_text,draw_centermap

# Load the experiment config; the returned (cfg, DATASET) pair is what get_dataset() is called with below.
cfg, DATASET = get_config("config/freihand/cfg_freihand_hg_ms_att.py")
# Visualize one sample per batch. 'workers' presumably sets the DataLoader worker count — confirm in get_dataset.
cfg['batch_size'] = 1
cfg['workers'] = 1

from tensorboardX import SummaryWriter
import torchvision.utils as vutils

# Side length used to clip drawn boxes (draw_region_bbox) and passed to evaluate_ap.
# NOTE(review): assumes cfg["image_size"] describes a square input and its first entry is that side — confirm.
new_size = cfg["image_size"][0]
def draw_region_bbox(img, xywhc, color=(0, 0, 255)):
    """Draw a center-format box on ``img`` and return whatever ``draw_bbox`` returns.

    ``xywhc`` starts with (cx, cy, w, h); any trailing entries (e.g. a score) are
    ignored. The corners are clipped to [0, new_size] and converted with ``int``
    before being handed to ``draw_bbox``.
    """
    cx, cy, w, h = xywhc[:4]
    half_w = w / 2
    half_h = h / 2
    left = int(max(0, cx - half_w))
    top = int(max(0, cy - half_h))
    right = int(min(cx + half_w, new_size))
    bottom = int(min(cy + half_h, new_size))
    return draw_bbox(img, left, top, right, bottom, color=color)

def image_recovery(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo mean/std normalization and return a displayable BGR uint8 image.

    Args:
        image (tensor): normalized image of shape (C, H, W), or (N, C, H, W) in
            which case only the first sample is used. May live on any device.
        mean (sequence of float): per-channel mean that was subtracted during
            normalization (presumably the ImageNet statistics used in training).
        std (sequence of float): per-channel std the image was divided by.

    Returns:
        numpy.ndarray: (H, W, 3) uint8 array with the channel axis reversed
        (RGB -> BGR), ready for OpenCV drawing functions.
    """
    if image.ndim == 4:
        image = image[0]
    # .cpu() so tensors that still live on the GPU can be converted to numpy.
    image = image.permute(1, 2, 0).detach().cpu().numpy()
    image = image * np.asarray(std) + np.asarray(mean)
    # Clip before casting: values outside [0, 255] would otherwise wrap around in uint8.
    image = np.clip(image * 255, 0, 255).astype(np.uint8)
    # Reverse the channel axis (same result as cv2.COLOR_RGB2BGR); make it contiguous
    # because OpenCV and torch.tensor reject negative-stride arrays.
    return np.ascontiguousarray(image[..., ::-1])

def make_heatmaps(image, heatmaps):
    """Build a horizontal strip: the input image followed by one overlay per heatmap.

    Args:
        image (tensor): normalized image accepted by ``image_recovery`` (CHW, or NCHW
            where only the first sample is used).
        heatmaps (tensor): (num_joints, H, W) maps. They are scaled by 255 and clamped
            to [0, 255], so they are presumably in [0, 1].

    Returns:
        numpy.ndarray: uint8 array of shape (H, (num_joints + 1) * W, 3) in RGB order.
        The first W columns hold the resized image. Each following W-wide tile is
        0.7 * JET-colored heatmap + 0.3 * image.
    """
    # image_recovery returns BGR; swap back to RGB for the output grid.
    image = cv2.cvtColor(image_recovery(image), cv2.COLOR_BGR2RGB)

    heatmaps = heatmaps.mul(255).clamp(0, 255).byte().cpu().numpy()

    num_joints, height, width = heatmaps.shape
    image_resized = cv2.resize(image, (int(width), int(height)))

    image_grid = np.zeros((height, (num_joints + 1) * width, 3), dtype=np.uint8)
    image_grid[:, 0:width, :] = image_resized

    for j in range(num_joints):
        # applyColorMap yields BGR; convert it so it matches the RGB background
        # (previously the overlay colors were channel-swapped).
        colored_heatmap = cv2.cvtColor(cv2.applyColorMap(heatmaps[j], cv2.COLORMAP_JET),
                                       cv2.COLOR_BGR2RGB)
        image_fused = colored_heatmap * 0.7 + image_resized * 0.3
        image_grid[:, width * (j + 1):width * (j + 2), :] = image_fused

    return image_grid


In [None]:
#  1、用于可视化用SimDR检测关键点的手部检测
from models.center_simDR import LiteHourglassNet as Network
from utils.result_parser import ResultParser

# Module-level placeholders for the batch fields; TestPreds.test unpacks its own local copies.
img = target_x = target_y = target_weight = centermap = centermask = bbox = gt_kpts = None
class TestPreds:
    """Evaluate/visualize the SimDR hand detector (``Network``, i.e. LiteHourglassNet
    from models.center_simDR) and log the results to TensorBoard.

    NOTE(review): a second class named ``TestPreds`` is defined in the next cell and
    shadows this one once that cell has run.
    """
    def __init__(self, checkpoint="", is_cuda=True, ground_truth=False, exp_name=''):
        """Build the test loader, the model, the result parser and a TensorBoard writer.

        Args:
            checkpoint (str): path passed to ``torch.load``; its ``'state_dict'`` entry is
                merged via ``load_pretrained_state``. "" skips loading. ``torch.load``
                unpickles the file, so only use trusted checkpoints.
            is_cuda (bool): use CUDA device 0 when True, otherwise the CPU.
            ground_truth (bool): when True, ``test`` feeds the batch's own centermap /
                target_x / target_y to the parser instead of the model outputs.
            exp_name (str): logs go to ``'jupyter_log/' + exp_name``.
        """

        print("preparing data...")
        self.dataset, self.test_loader = get_dataset(cfg, DATASET, is_train=False, distributed=False)
        print("done!")

        if is_cuda:
            self.device = torch.device(0)
        else:
            self.device = torch.device('cpu')
        self.model = Network().to(self.device)

        self.ground_truth = ground_truth
        if checkpoint != "":
            print("loading state dict...")
            save_dict = torch.load(checkpoint, map_location=self.device)
            print(f"{save_dict.keys()=}")
            
            # Merge the checkpoint weights into the model's own state dict; is_match is printed below.
            state, is_match = load_pretrained_state(self.model.state_dict(),
                                                    save_dict['state_dict'])
            # self.model.load_state_dict(save_dict["state_dict"])
            self.model.load_state_dict(state)
            print(f"done! {is_match=}")
        
        self.writer = SummaryWriter(log_dir='jupyter_log/'+ exp_name)
        self.parser = ResultParser()

    def test(self, n_img=10, show_hms=True, show_kpts=True):
        """Run the model over the test loader and log visualizations.

        Args:
            n_img (int): batches with index ``i <= n_img`` are processed (n_img + 1
                batches); -1 processes the whole loader.
            show_hms (bool): log the batch's centermap rendered by ``draw_centermap``.
            show_kpts (bool): log the de-normalized image together with a copy that has
                the predicted keypoints and predicted box drawn on it.

        NOTE(review): ``pck`` is never accumulated, so the final printed ``pck`` is
        always 0; ``ap50``/``ap`` are computed but never used.
        """
        self.model.eval()
        with torch.no_grad():
            pck = 0
            if n_img == -1:
                n_img = len(self.test_loader)  # -1: go through the entire dataset
            for i, meta in enumerate(self.test_loader):
                # Stops only once i exceeds n_img, i.e. n_img + 1 batches are processed.
                if i > n_img:
                    break
                start = cv2.getTickCount()  # timer start
                img, target_x, target_y, target_weight, centermap, centermask, bbox, gt_kpts = meta
                
                if not self.ground_truth:
                    pred_centermap, pred_x, pred_y = self.model(img.to(self.device))
                else:
                    # Bypass the model and parse the ground-truth targets instead.
                    pred_centermap, pred_x, pred_y = centermap, target_x, target_y
         
                # Parse the outputs to get keypoints and bounding boxes in original-image coordinates
                pred_kpts, pred_bboxes = self.parser.parse(pred_centermap, pred_x, pred_y)
                ap50, ap = self.parser.evaluate_ap(pred_bboxes, bbox)

                # Draw heatmaps
                # NOTE(review): this draws the ground-truth `centermap`, not `pred_centermap` — confirm intent.
                if show_hms:
                    out_centermap = draw_centermap(centermap)
                    out_centermap = [torch.tensor(out, dtype=torch.uint8) for out in out_centermap]
                    imgs = torch.stack(out_centermap, dim=0)
                    self.writer.add_image('centermap', imgs, i, dataformats='NHWC')

                # Draw keypoints
                if show_kpts:
                    batch_xywh = pred_bboxes[0]
                    if batch_xywh is None:
                        print("没有找到目标框")
                        # Placeholder box so the drawing code below still runs when nothing was found.
                        batch_xywh = [[4, 4, 2, 2, 0]]

                    for image, kpts, xywh in zip(img, pred_kpts, batch_xywh):  
                    # for image, kpts, xywh in zip(img, gt_kpts, batch_xywh):    
                        # Undo the mean/std normalization and convert to uint8.
                        # NOTE(review): no clipping before astype, so out-of-range values wrap around.
                        image = image.permute(1, 2, 0).detach().numpy()
                        m = np.array([0.485, 0.456, 0.406])
                        s = np.array([0.229, 0.224, 0.225])
                        image = image * s + m
                        image *= 255
                        image = image.astype(np.uint8) 
                        
                        # kpts = kpts.squeeze(dim=0).detach().cpu().numpy()
                        kpts = kpts[0].detach().cpu().numpy()
                        # print(f"{image.shape=}")
                        # Debug output for the current batch.
                        print(f"{bbox=}")
                        print(f"{xywh=}")
                        print(f"{i=} :{kpts=}")
                        print(F"{i=} :{gt_kpts[0, 0]=}")
                        print('*'* 100)
                        
                        # RGB -> BGR channel swap before drawing with the OpenCV-based helpers.
                        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
                        image_drawn = draw_point(img=image.copy(), keypoints=kpts)
                        image_drawn = draw_region_bbox(image_drawn, xywh)

                    end = cv2.getTickCount()  # timer end
                    fps = round(cv2.getTickFrequency() / (end - start))
                    text = str(fps) + "fps"
                    # NOTE(review): the return value is bound to `img` and not logged; the fps text only
                    # shows up if draw_text draws in place on image_drawn — confirm.
                    img = draw_text(image_drawn, text, (15, 15, 20, 20))
                    # img = img[:,:,::-1]   # BGR to RGB
                    # Only the last sample of the loop above is logged.
                    imgs = torch.stack([torch.tensor(image, dtype=torch.uint8),
                                        torch.tensor(image_drawn, dtype=torch.uint8)], dim=0)
                    self.writer.add_images('images', imgs, i, dataformats='NHWC')
                    
            # NOTE(review): pck is still 0 here (never accumulated above).
            pck = pck / (n_img+1)
            print(f"{n_img=}")
            print(f"{pck=}")
            self.writer.close()
            

In [2]:
#  2、用于可视化用热图检测关键点的手部检测
from utils.evaluation import evaluate_pck, evaluate_ap, get_coordinates_from_heatmap
from utils.result_parser import ResultParser
from models.pose_estimation import get_model
import os

class TestPreds:
    """Evaluate and visualize a heatmap-based hand keypoint model.

    Uses the module-level ``cfg``/``DATASET`` to build the data loader and the model
    (``get_model(cfg)``), optionally restores a checkpoint, and logs region maps,
    keypoint heatmaps and annotated images to TensorBoard under
    ``./jupyter_log/<exp_name>``.

    NOTE(review): this re-defines (and shadows) the SimDR ``TestPreds`` from the
    previous cell.
    """

    def __init__(self, checkpoint="", is_cuda=True, exp_name='cs', rm=False, is_train=True):
        """Prepare data, model, optional checkpoint and the TensorBoard writer.

        Args:
            checkpoint (str): path to a checkpoint dict whose ``'state_dict'`` entry is
                merged through ``load_pretrained_state``; "" skips loading. The file is
                read with ``torch.load`` (pickle), so only load trusted checkpoints.
            is_cuda (bool): use CUDA device 0 when True, otherwise the CPU.
            exp_name (str): log sub-directory under ``./jupyter_log``.
            rm (bool): delete existing event files in that directory before logging.
            is_train (bool): forwarded to ``get_dataset``; defaults to True, the value
                that was previously hard-coded here.
        """
        print("preparing data...")
        self.dataset, self.test_loader = get_dataset(cfg, DATASET,
                                                     is_train=is_train,
                                                     distributed=False)
        print("done!")

        self.device = torch.device(0) if is_cuda else torch.device('cpu')
        self.model = get_model(cfg).to(self.device)

        # When True, the model returns (output, pred_x, pred_y) instead of just output.
        self.with_simdr = cfg['simdr_split_ratio'] > 0

        if checkpoint != "":
            print("loading state dict...")
            save_dict = torch.load(checkpoint, map_location=self.device)
            print(f"{save_dict.keys()=}")

            state, is_match = load_pretrained_state(self.model.state_dict(),
                                                    save_dict['state_dict'])
            self.model.load_state_dict(state)
            print(f"done! {is_match=}")

        self.result_parser = ResultParser(cfg)

        logdir = os.path.join('./jupyter_log', exp_name)
        if rm and os.path.exists(logdir):
            self._clear_event_files(logdir)

        self.writer = SummaryWriter(log_dir=logdir)

    @staticmethod
    def _clear_event_files(logdir):
        """Delete the regular files in ``logdir`` whose name contains 'events'.

        The substring test is applied to each file name (the previous version tested
        the list itself, so nothing was ever removed). Sub-directories are left alone.
        """
        for name in os.listdir(logdir):
            path = os.path.join(logdir, name)
            if 'events' in name and os.path.isfile(path):
                os.remove(path)

    def test(self, n_img=10, show_hms=True, show_kpts=True, show_cd=False, stride=20):
        """Evaluate PCK on a subsample of the loader and log visualizations.

        Only batches whose index is a multiple of ``stride`` are evaluated, and at most
        ``n_img`` of them. The running PCK sum is printed after each batch. At the end,
        the mean over the batches actually evaluated is printed.

        Args:
            n_img (int): maximum number of batches to evaluate; -1 means no limit
                (every ``stride``-th batch of the loader).
            show_hms (bool): log the region maps and keypoint heatmaps of the first sample.
            show_kpts (bool): log the recovered image plus ground-truth, first-pass and
                second-pass keypoint/box overlays.
            show_cd (bool): replace the batch with ``self.dataset.generate_cd_gt(...)``
                output (to show the effect of cyclic detection) and use one box that
                covers the whole image.
            stride (int): sampling stride over batch indices (previously fixed at 20).
        """
        self.model.eval()
        with torch.no_grad():
            pck = 0
            if n_img == -1:
                n_img = len(self.test_loader)  # -1: go through the whole dataset

            check_num = 0  # number of batches actually evaluated
            for i, (img, targets, target_weight, bbox, gt_kpts, target_x, target_y) \
                    in enumerate(self.test_loader):
                if i % stride != 0:
                    continue
                if check_num >= n_img:
                    break
                check_num += 1

                start = cv2.getTickCount()  # timer start

                if show_cd:  # show the effect of cyclic detection
                    img, targets, gt_kpts, target_x, target_y = \
                        self.dataset.generate_cd_gt(img, gt_kpts, bbox, target_weight)

                    # One box (cx, cy, w, h) covering the whole image.
                    bbox = torch.tensor([[[img.shape[3] / 2,
                                           img.shape[2] / 2,
                                           img.shape[3],
                                           img.shape[2]]]], device=bbox.device)

                if self.with_simdr:
                    output, pred_x, pred_y = self.model(img.to(self.device))
                else:
                    output = self.model(img.to(self.device))

                # Last-stage output: the final 3 channels are treated as region maps,
                # the remaining channels as keypoint heatmaps.
                region = output[-1][:, -3:]
                hm_kpts = output[-1][:, :-3].to(self.device)

                pck += evaluate_pck(hm_kpts, targets[-1][:, :-3], bbox, thr=0.2).item()
                print(f"{pck=}")

                # Draw heatmaps
                if show_hms:
                    img_region = make_heatmaps(img, region[0])
                    img_heatmap = make_heatmaps(img, hm_kpts[0])

                    img_region = torch.tensor(img_region, dtype=torch.uint8)
                    self.writer.add_image('region', img_region, i, dataformats='HWC')

                    img_heatmap = torch.tensor(img_heatmap, dtype=torch.uint8)
                    self.writer.add_image('heatmap', img_heatmap, i, dataformats='HWC')

                # Draw keypoints
                if show_kpts:
                    _, _, batch_xywh = evaluate_ap(region, bbox, new_size, k=10, iou_thr=0.3)
                    batch_xywh = batch_xywh[0]
                    if batch_xywh is None:
                        print("没有找到目标框")
                        # Placeholder box so the drawing code below still runs.
                        batch_xywh = [[4, 4, 2, 2, 0]]

                    # ! Single-hand data for now, so only the first entry is kept.
                    bbox = bbox[0]
                    gt_kpts = gt_kpts[:, 0]

                    first_time_kpt = self.result_parser.get_pred_kpt(hm_kpts, resized=True)
                    pred_bboxes = self.result_parser.get_pred_bbox(region)
                    second_time_kpt = self.result_parser.get_group_keypoints(
                        self.model, img, pred_bboxes, hm_kpts)

                    for image, gk, ftk, stk, xywh, gt_xywh in zip(img, gt_kpts,
                                                                  first_time_kpt, second_time_kpt,
                                                                  batch_xywh, bbox):
                        image = image_recovery(image)
                        image_gt = draw_results(image, gk, gt_xywh, gt_xywh)
                        img_first_pred = draw_results(image, ftk, xywh, gt_xywh)
                        img_second_pred = draw_results(image, stk, xywh, gt_xywh)

                    end = cv2.getTickCount()  # timer end
                    fps = round(cv2.getTickFrequency() / (end - start))
                    text = str(fps) + "fps"
                    img = draw_text(img_first_pred, text, (15, 15, 20, 20))
                    # Only the last sample of the loop above is logged.
                    imgs = torch.stack([torch.tensor(image, dtype=torch.uint8),
                                        torch.tensor(image_gt, dtype=torch.uint8),
                                        torch.tensor(img_first_pred, dtype=torch.uint8),
                                        torch.tensor(img_second_pred, dtype=torch.uint8),
                                        ], dim=0)
                    self.writer.add_images('images', imgs, i, dataformats='NHWC')

            # Average over the batches that were actually evaluated.
            pck = pck / max(check_num, 1)
            print(f"{n_img=}")
            print(f"{check_num=}")
            print(f"{pck=}")
            self.writer.close()

def draw_results(image, kpts, xywh, gt_xywh):
    """Overlay keypoints, a predicted box and a ground-truth box on a copy of ``image``.

    Args:
        image (numpy): (h, w, c) image; a copy is drawn on, the argument itself is
            not passed to the drawing helpers.
        kpts (tensor): (1, n_joints, 3) keypoints; a leading singleton dim is squeezed.
        xywh (tensor): (5,) predicted box (cx, cy, w, h, score), drawn with color (255, 0, 0).
        gt_xywh (tensor): (4,) ground-truth box (cx, cy, w, h), drawn with color (0, 255, 0).

    Returns:
        The image returned by the last ``draw_region_bbox`` call.
    """
    points = kpts.squeeze(dim=0).detach().cpu().numpy()
    annotated = draw_point(img=image.copy(), keypoints=points)
    for box, box_color in ((xywh, (255, 0, 0)), (gt_xywh, (0, 255, 0))):
        annotated = draw_region_bbox(annotated, box, box_color)
    return annotated
    

In [4]:
# Alternative checkpoint: "./checkpoint/MSRB-D-DW-PELEE/1HG-ME-att-c256/2021-12-27/0.981_PCK_47epoch.pt"
path = "output/freihand/steplr/freihand_plus_hg_ms_att/2022-03-20/best_pck.pt"
t = TestPreds(
    checkpoint=path,
    exp_name='freihand_plus/test',
    is_cuda=False,
    rm=True,
)
t.test(
    n_img=300,
    show_cd=False,
    show_hms=True,
    show_kpts=True,
)
# View the logs with:
# tensorboard --samples_per_plugin scalars=100,images=100 --logdir "./jupyter_log/<exp_name>"

preparing data...
sample number of training dataset:  104192
done!
loading state dict...
save_dict.keys()=dict_keys(['epoch', 'PCK', 'AP', 'AP50', 'state_dict', 'optimizer', 'config'])
done! is_match=True
pck=1.0
pck=1.0
pck=1.9047619104385376
pck=2.0000000074505806
pck=2.0000000074505806
pck=2.0000000074505806
pck=2.5238095596432686
pck=2.5238095596432686
pck=2.5238095596432686
pck=2.952380992472172
pck=3.9047619476914406
pck=3.9047619476914406
pck=4.238095290958881
pck=4.238095290958881
pck=5.238095290958881
pck=5.238095290958881
pck=5.238095290958881
pck=6.19047624617815
pck=7.19047624617815
pck=8.19047624617815
pck=8.285714343190193
pck=9.285714343190193
pck=9.285714343190193
pck=9.333333391696215
pck=9.666666734963655
pck=9.761904831975698
pck=10.761904831975698
pck=10.761904831975698
pck=11.047619130462408
pck=11.333333428949118
pck=12.238095339387655
pck=12.333333436399698
pck=13.285714391618967
pck=14.238095346838236
pck=14.238095346838236
pck=15.238095346838236
pck=16.09523821