In [1]:
import torch
import numpy as np
import cv2
from utils.training_kits import stdout_to_tqdm, load_pretrained_state

from data import get_dataset
from config.config import DATASET, config_dict
from utils.visualization_tools import draw_heatmaps, draw_region_maps, draw_point, draw_bbox, draw_text,draw_centermap

from tensorboardX import SummaryWriter
import torchvision.utils as vutils

# Evaluate one sample per step with a single loader worker: the drawing code
# below only uses the boxes of the first sample of each batch.
for _key in ('batch_size', 'workers'):
    config_dict[_key] = 1

# Side length used to clip drawn boxes (presumably the network input size).
new_size = config_dict["image_size"][0]
def draw_region_bbox(img, xywhc, size=None):
    """Draw a centre-format box on ``img`` after clipping it to the image bounds.

    Args:
        img: image passed unchanged to ``draw_bbox``.
        xywhc: sequence whose first four entries are (cx, cy, w, h) in pixels;
            any further entries (e.g. a confidence score) are ignored.
        size: bound used for both x and y when clipping; defaults to the
            module-level ``new_size`` (``config_dict["image_size"][0]``).

    Returns:
        Whatever ``draw_bbox`` returns for the clipped corners.
    """
    if size is None:
        size = new_size
    cx, cy, w, h = (float(v) for v in xywhc[:4])

    def clip(v):
        # Clamp on both sides so a box partly or fully outside the frame never
        # yields negative or > size corners (the old code clamped one side only).
        return int(min(max(v, 0), size))

    x1, y1 = clip(cx - w / 2), clip(cy - h / 2)
    x2, y2 = clip(cx + w / 2), clip(cy + h / 2)
    return draw_bbox(img, x1, y1, x2, y2)


In [8]:
#  1. Visualise hand detection whose keypoints are predicted with SimDR (Center-SimDR model)
from models.center_simDR import LiteHourglassNet as Network
from utils.CenterSimDRParser import ResultParser

# Module-level placeholders named after the batch fields unpacked in TestPreds.test().
# NOTE(review): test() binds its own locals, so these stay None — likely leftover scaffolding.
img = target_x = target_y = target_weight = centermap = centermask = bbox = gt_kpts = None
class TestPreds:
    """Visualise / evaluate the Center-SimDR hand detector on the test split.

    For every test batch, the centre map and SimDR x/y vectors are parsed into
    keypoints and boxes by ``ResultParser``. These come from the model outputs,
    or from the ground-truth targets when ``ground_truth=True``. Box AP50 / AP is
    accumulated, and the drawn results are logged to TensorBoard under
    ``jupyter_log/<exp_name>``.
    """

    # ImageNet mean/std used to undo the input normalisation before drawing.
    # NOTE(review): assumed to match the normalisation applied in ``data`` — confirm.
    _MEAN = np.array([0.485, 0.456, 0.406])
    _STD = np.array([0.229, 0.224, 0.225])

    def __init__(self, checkpoint="", is_cuda=True, ground_truth=False, exp_name=''):
        """Build the test loader, the model and the TensorBoard writer.

        Args:
            checkpoint: path of a saved dict containing a ``'state_dict'`` entry;
                "" keeps the freshly initialised weights.
            is_cuda: use CUDA device 0 when available (falls back to CPU otherwise).
            ground_truth: parse the targets instead of the model outputs. This
                checks the parsing/drawing pipeline independently of the model.
            exp_name: sub-directory of ``jupyter_log/`` for the TensorBoard logs.
        """
        # Import the network here instead of using the notebook-global ``Network``.
        # A later cell rebinds ``Network`` to a different model class, which would
        # otherwise silently change what this class instantiates.
        from models.center_simDR import LiteHourglassNet

        print("preparing data...")
        self.dataset, self.test_loader = get_dataset(set_type='test')
        print("done!")

        if is_cuda and not torch.cuda.is_available():
            print("CUDA requested but not available, falling back to CPU")
            is_cuda = False
        self.device = torch.device(0) if is_cuda else torch.device('cpu')
        self.model = LiteHourglassNet().to(self.device)

        self.ground_truth = ground_truth
        if checkpoint != "":
            print("loading state dict...")
            # torch.load unpickles the file: only load checkpoints from trusted sources.
            save_dict = torch.load(checkpoint, map_location=self.device)
            print(f"{save_dict.keys()=}")
            # Presumably merges matching checkpoint weights into the model's own
            # state dict; ``is_match`` semantics are defined by load_pretrained_state.
            state, is_match = load_pretrained_state(self.model.state_dict(),
                                                    save_dict['state_dict'])
            self.model.load_state_dict(state)
            print(f"done! {is_match=}")

        self.writer = SummaryWriter(log_dir='jupyter_log/' + exp_name)
        self.parser = ResultParser()

    @classmethod
    def _denormalize(cls, image_chw):
        """Turn one normalised CHW image tensor into an HWC ``uint8`` RGB array."""
        image = image_chw.permute(1, 2, 0).detach().cpu().numpy()
        image = (image * cls._STD + cls._MEAN) * 255
        # Clip before the cast: values outside [0, 255] are not representable in
        # uint8, and the unclipped cast result is platform dependent (speckle noise).
        return np.clip(image, 0, 255).astype(np.uint8)

    def _log_centermap(self, tag, centermap, step):
        """Render a batch of centre maps with ``draw_centermap`` and log them as an NHWC grid."""
        frames = [torch.tensor(out, dtype=torch.uint8) for out in draw_centermap(centermap)]
        self.writer.add_image(tag, torch.stack(frames, dim=0), step, dataformats='NHWC')

    def _log_keypoints(self, i, start, img, bbox, gt_kpts, pred_kpts, pred_bboxes, verbose):
        """Draw parsed keypoints and boxes on the batch images and log the last original/drawn pair."""
        batch_xywh = pred_bboxes[0]
        if batch_xywh is None:
            print("没有找到目标框")  # "no target box found"
            batch_xywh = [[4, 4, 2, 2, 0]]  # tiny dummy box so drawing still proceeds

        image = image_drawn = None
        # NOTE(review): the boxes of the *first* sample are zipped against every image
        # of the batch; this is only consistent with batch_size == 1 (set at the top).
        for image_t, kpts, xywh in zip(img, pred_kpts, batch_xywh):
            image = self._denormalize(image_t)
            kpts = kpts[0].detach().cpu().numpy()  # first parsed detection only
            if verbose:
                print(f"{bbox=}")
                print(f"{xywh=}")
                print(f"{i=} ：{kpts=}")
                print(F"{i=} ：{gt_kpts[0, 0]=}")
                print('*' * 100)

            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            image_drawn = draw_point(img=image.copy(), keypoints=kpts)
            image_drawn = draw_region_bbox(image_drawn, xywh)

        if image_drawn is None:
            # Empty parser output: previously a NameError on ``image_drawn``.
            print("nothing to draw for this batch")
            return

        end = cv2.getTickCount()  # timer end
        fps = round(cv2.getTickFrequency() / (end - start))
        # NOTE(review): draw_text's return value was never used; this presumably
        # relies on it drawing onto ``image_drawn`` in place.
        draw_text(image_drawn, str(fps) + "fps", (15, 15, 20, 20))
        imgs = torch.stack([torch.tensor(image, dtype=torch.uint8),
                            torch.tensor(image_drawn, dtype=torch.uint8)], dim=0)
        self.writer.add_images('images', imgs, i, dataformats='NHWC')

    def test(self, n_img=10, show_hms=True, show_kpts=True, verbose=True):
        """Run the detector over the first ``n_img`` test batches and log the results.

        Args:
            n_img: number of batches to process; -1 processes the whole loader.
            show_hms: log the ground-truth centre map (tag ``centermap``). When
                the model is running, also log the predicted one (tag ``pred_centermap``).
            show_kpts: draw parsed keypoints and boxes, and log original/drawn
                image pairs (tag ``images``).
            verbose: print per-image debug information (boxes and keypoints).

        Prints ``n_img`` and the mean AP50 / AP over the processed batches, then
        closes the writer. NOTE(review): assumes ``evaluate_ap`` returns numeric
        scalars (Python numbers or 0-d tensors).
        """
        self.model.eval()
        if n_img == -1:
            n_img = len(self.test_loader)  # -1: run over the whole test set
        ap50_sum, ap_sum, n_done = 0, 0, 0
        with torch.no_grad():
            for i, meta in enumerate(self.test_loader):
                # ``>=`` so exactly ``n_img`` batches are processed (``>`` ran one extra).
                if i >= n_img:
                    break
                start = cv2.getTickCount()  # timer start: inference + parsing + drawing
                img, target_x, target_y, target_weight, centermap, centermask, bbox, gt_kpts = meta

                if self.ground_truth:
                    pred_centermap, pred_x, pred_y = centermap, target_x, target_y
                else:
                    pred_centermap, pred_x, pred_y = self.model(img.to(self.device))

                # Parse results into keypoints and boxes on the original image.
                pred_kpts, pred_bboxes = self.parser.parse(pred_centermap, pred_x, pred_y)
                ap50, ap = self.parser.evaluate_ap(pred_bboxes, bbox)
                ap50_sum += ap50
                ap_sum += ap
                n_done += 1

                if show_hms:
                    self._log_centermap('centermap', centermap, i)
                    if not self.ground_truth:
                        self._log_centermap('pred_centermap', pred_centermap.detach().cpu(), i)

                if show_kpts:
                    self._log_keypoints(i, start, img, bbox, gt_kpts,
                                        pred_kpts, pred_bboxes, verbose)

            print(f"{n_img=}")
            if n_done:
                print(f"mean_ap50={ap50_sum / n_done}")
                print(f"mean_ap={ap_sum / n_done}")
            self.writer.close()
            

In [7]:
#  2. Visualise hand detection whose keypoints are predicted with heat-maps (HourglassNet_SA)
from models.hourglass_SA import HourglassNet_SA as Network
from utils.evaluation import evaluate_pck, evaluate_ap, get_coordinates_from_heatmap


class TestPreds:
    """Visualise / evaluate the heat-map hand detector (HourglassNet_SA) on the test split.

    For every test batch, the region map (channels 0-2) and keypoint heat-maps
    (channels 3+) are taken from the model output, or from the target when
    ``ground_truth=True``. Keypoints are scored with ``evaluate_pck`` (thr=0.2),
    optionally drawn, and image overlays are logged to TensorBoard.
    """

    # ImageNet mean/std used to undo the input normalisation before drawing.
    # NOTE(review): assumed to match the normalisation applied in ``data`` — confirm.
    _MEAN = np.array([0.485, 0.456, 0.406])
    _STD = np.array([0.229, 0.224, 0.225])

    def __init__(self, checkpoint="", is_cuda=True, ground_truth=False, exp_name=''):
        """Build the test loader, the model and the TensorBoard writer.

        Args:
            checkpoint: path of a saved dict containing a ``'state_dict'`` entry;
                "" keeps the freshly initialised weights.
            is_cuda: use CUDA device 0 when available (falls back to CPU otherwise).
            ground_truth: use the targets instead of the model outputs.
            exp_name: optional sub-directory of ``jupyter_log`` for the logs
                (same parameter as the Center-SimDR test class).
        """
        # Import the network here instead of using the notebook-global ``Network``,
        # which another cell rebinds to a different model class.
        from models.hourglass_SA import HourglassNet_SA

        print("preparing data...")
        self.dataset, self.test_loader = get_dataset(set_type='test')
        print("done!")

        if is_cuda and not torch.cuda.is_available():
            print("CUDA requested but not available, falling back to CPU")
            is_cuda = False
        self.device = torch.device(0) if is_cuda else torch.device('cpu')
        self.model = HourglassNet_SA().to(self.device)

        self.ground_truth = ground_truth
        if checkpoint != "":
            print("loading state dict...")
            # torch.load unpickles the file: only load checkpoints from trusted sources.
            save_dict = torch.load(checkpoint, map_location=self.device)
            print(f"{save_dict.keys()=}")
            # Presumably merges matching checkpoint weights into the model's own
            # state dict; ``is_match`` semantics are defined by load_pretrained_state.
            state, is_match = load_pretrained_state(self.model.state_dict(),
                                                    save_dict['state_dict'])
            self.model.load_state_dict(state)
            print(f"done! {is_match=}")

        # Keep the original log directory when no experiment name is given.
        log_dir = 'jupyter_log/' + exp_name if exp_name else 'jupyter_log'
        self.writer = SummaryWriter(log_dir=log_dir)

    @classmethod
    def _denormalize(cls, image_chw):
        """Turn one normalised CHW image tensor into an HWC ``uint8`` RGB array."""
        image = image_chw.permute(1, 2, 0).detach().cpu().numpy()
        image = (image * cls._STD + cls._MEAN) * 255
        # Clip before the cast: values outside [0, 255] are not representable in
        # uint8, and the unclipped cast result is platform dependent (speckle noise).
        return np.clip(image, 0, 255).astype(np.uint8)

    def _log_keypoints(self, i, start, img, region, hm_kpts, bbox, verbose):
        """Decode boxes/keypoints, draw them on the batch images and log the last original/drawn pair."""
        _, _, batch_xywh = evaluate_ap(region, bbox, k=10, conf_thr=0.1)
        batch_xywh = batch_xywh[0]
        if batch_xywh is None:
            print("没有找到目标框")  # "no target box found"
            batch_xywh = [[4, 4, 2, 2, 0]]  # tiny dummy box so drawing still proceeds

        batch_kpts, _ = get_coordinates_from_heatmap(hm_kpts)
        heatmaps_size = hm_kpts.shape[-1]
        # Heat-map pixel coordinates -> input-image coordinates.
        batch_kpts[..., :2] = batch_kpts[..., :2] * new_size / heatmaps_size

        image = image_drawn = None
        # NOTE(review): the boxes of the *first* sample are zipped against every image
        # of the batch; this is only consistent with batch_size == 1 (set at the top).
        for image_t, kpts, xywh in zip(img, batch_kpts, batch_xywh):
            image = self._denormalize(image_t)
            kpts = kpts.squeeze(dim=0).detach().cpu().numpy()
            if verbose:
                print(f"{image.shape=}")
                print(f"{kpts.shape=}")
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            image_drawn = draw_point(img=image.copy(), keypoints=kpts)
            # Module-level helper: the original ``self.draw_region_bbox`` does not
            # exist on this class and raised AttributeError.
            image_drawn = draw_region_bbox(image_drawn, xywh)

        if image_drawn is None:
            # Nothing decoded: previously a NameError on ``image_drawn``.
            print("nothing to draw for this batch")
            return

        end = cv2.getTickCount()  # timer end
        fps = round(cv2.getTickFrequency() / (end - start))
        # NOTE(review): draw_text's return value was never used; this presumably
        # relies on it drawing onto ``image_drawn`` in place.
        draw_text(image_drawn, str(fps) + "fps", (15, 15, 20, 20))
        imgs = torch.stack([torch.tensor(image, dtype=torch.uint8),
                            torch.tensor(image_drawn, dtype=torch.uint8)], dim=0)
        self.writer.add_images('images', imgs, i, dataformats='NHWC')

    def test(self, n_img=10, show_hms=True, show_kpts=True, verbose=True):
        """Run over the first ``n_img`` test batches, accumulate PCK and log visualisations.

        Args:
            n_img: number of batches to process; -1 processes the whole loader.
            show_hms: pass region maps, the target "mask" channel and the keypoint
                heat-maps to ``draw_region_maps`` / ``draw_heatmaps``.
            show_kpts: draw decoded keypoints and boxes, and log image pairs to TensorBoard.
            verbose: print the running PCK sum and per-image shapes.

        Prints ``n_img`` and the mean PCK over the processed batches, then closes the writer.
        """
        self.model.eval()
        if n_img == -1:
            n_img = len(self.test_loader)  # -1: run over the whole test set
        pck, n_done = 0.0, 0
        with torch.no_grad():
            for i, (img, hm_target, hm_weight, label, bbox) in enumerate(self.test_loader):
                # ``>=`` so exactly ``n_img`` batches are processed (``>`` ran one extra).
                if i >= n_img:
                    break
                start = cv2.getTickCount()  # timer start

                # Last entry of the target list: channels 0-2 are used as the region
                # map, channels 3+ as keypoint heat-maps.
                target = hm_target[-1]
                region = target[:, :3]
                hm_kpts = target[:, 3:].to(self.device)
                mask = hm_kpts[:, 0]  # first keypoint channel of the *target*

                if not self.ground_truth:
                    hms_list = self.model(img.to(self.device))  # [22, 22, 22, 44, 88]
                    hm_kpts = hms_list[-1][:, 3:]
                    region = hms_list[-1][:, :3]

                pck += evaluate_pck(hm_kpts, target[:, 3:], bbox, thr=0.2).item()
                n_done += 1
                if verbose:
                    print(f"{pck=}")

                if show_hms:
                    draw_region_maps(region)
                    draw_heatmaps(mask)
                    draw_heatmaps(hm_kpts)

                if show_kpts:
                    self._log_keypoints(i, start, img, region, hm_kpts, bbox, verbose)

            # Average over the batches actually processed (not n_img + 1).
            pck = pck / n_done if n_done else 0.0
            print(f"{n_img=}")
            print(f"{pck=}")
            self.writer.close()

    @staticmethod
    def get_coordinates(batch_kpts_hm):
        """Arg-max keypoint decoding (alternative to ``get_coordinates_from_heatmap``).

        Args:
            batch_kpts_hm: tensor of shape (batch, n_joints, h, w).

        Returns:
            Tensor of shape (batch, n_joints, 3) created with ``torch.zeros`` defaults,
            holding x (column index), y (row index) and the peak value per joint.
        """
        batch, n_joints, h, w = batch_kpts_hm.shape
        top_val, top_idx = torch.topk(batch_kpts_hm.reshape((batch, n_joints, -1)), k=1)

        batch_kpts = torch.zeros((batch, n_joints, 3))
        batch_kpts[..., 0] = (top_idx % w).reshape((batch, n_joints))  # x
        batch_kpts[..., 1] = (top_idx // w).reshape((batch, n_joints))  # y
        batch_kpts[..., 2] = top_val.reshape((batch, n_joints))  # c: score

        return batch_kpts

KeyError: 'increase'

In [9]:
# Evaluate the Center-SimDR checkpoint on CPU using ground-truth targets;
# TensorBoard logs go to jupyter_log/cs.
# NOTE(review): two cells define `TestPreds`; this call expects the Center-SimDR
# version, i.e. that cell must have been executed after the heat-map cell.
# Earlier heat-map checkpoint:
# "./checkpoint/MSRB-D-DW-PELEE/1HG-ME-att-c256/2021-12-27/0.981_PCK_47epoch.pt"
path = "./checkpoint/Center_SimDR/1HG-lite/2022-01-06/0.269_PCK_42epoch.pt"
t = TestPreds(path, is_cuda=False, ground_truth=True, exp_name='cs')
t.test(50, show_hms=False, show_kpts=True)

preparing data...
sample number of testing dataset:  100
done!
loading state dict...
save_dict.keys()=dict_keys(['epoch', 'lr', 'loss', 'mPCK', 'ap', 'state_dict', 'optimizer', 'config'])
done! is_match=True
target_x[target_x>0.1]=tensor([0.10941, 0.12992, 0.15321,  ..., 0.16299, 0.13858, 0.11702])
target_y[target_y>0.1]=tensor([0.10548, 0.12544, 0.14814,  ..., 0.15858, 0.13467, 0.11358])
centermap[centermap>0.1]=tensor([0.10057, 0.10160, 0.10194,  ..., 0.45759, 0.48997, 0.66955])
bbox=tensor([[[129.95987, 118.67818, 151.98604, 117.14246]]])
xywh=[111.92759704589844, 239.44747924804688, 122.93728637695312, 60.902469635009766, 0.20637018978595734]
i=0 ：kpts.shape=(21, 3)
i=0 ：kpts=array([[      141.5,       156.5,      1.2558],
       [        130,       132.5,      3.7996],
       [        139,       111.5,      3.1018],
       [        127,       100.5,      2.2136],
       [      130.5,         123,      1.3536],
       [        130,       111.5,      4.2132],
       [      141.5,   