In [27]:
import torch
import numpy as np
import cv2
from utils.evaluation import evaluate_pck, evaluate_ap, get_coordinates_from_heatmap
from utils.training_kits import stdout_to_tqdm, load_pretrained_state

from models.hourglass_SA import HourglassNet_SA as Network
from data import get_dataset
from config.config import DATASET, config_dict
from utils.visualization_tools import draw_heatmaps, draw_region_maps, draw_point, draw_bbox, draw_text
from time import sleep
from tensorboardX import SummaryWriter
import torchvision.utils as vutils

In [60]:
new_size = config_dict["image_size"][0]
class TestPreds:
    def __init__(self, checkpoint="", is_cuda=True, ground_truth=False):

        print("preparing data...")
        self.dataset, self.test_loader = get_dataset(set_type='test')
        print("done!")

        if is_cuda:
            self.device = torch.device(0)
        else:
            self.device = torch.device('cpu')
        self.model = Network().to(self.device)

        self.ground_truth = ground_truth
        if checkpoint != "":
            print("loading state dict...")
            save_dict = torch.load(checkpoint, map_location=self.device)
            print(f"{save_dict.keys()=}")
            
            state, is_match = load_pretrained_state(self.model.state_dict(),
                                                    save_dict['state_dict'])
            # self.model.load_state_dict(save_dict["state_dict"])
            self.model.load_state_dict(state)
            print(f"done! {is_match=}")
        
        self.writer = SummaryWriter(log_dir='jupyter_log')

    def test(self, n_img=10, show_hms=True, show_kpts=True):
        self.model.eval()
        with torch.no_grad():
            pck = 0
            if n_img == -1:
                n_img = len(self.test_loader)  # -1 整个数据集过一遍
            for i, (img, hm_target, hm_weight, label, bbox) in enumerate(self.test_loader):
                if i > n_img:
                    break
                # print(f"{bbox=}")
                # print(f"{bbox.shape=}")

                start = cv2.getTickCount()  # 计时器起点

                target = hm_target[-1]
                region = target[:, :3]
                hm_kpts = target[:, 3:].to(self.device)
                mask = hm_kpts[:, 0]

                if not self.ground_truth:
                    hms_list = self.model(img.to(self.device))  # [22, 22, 22, 44, 88]
                    hm_kpts = hms_list[-1][:, 3:]
                    region = hms_list[-1][:, :3]
                    
                pck += evaluate_pck(hm_kpts, target[:, 3:], bbox, thr=0.2).item()
                print(f"{pck=}")

                # 画出热图
                if show_hms:
                    draw_region_maps(region)
                    draw_heatmaps(mask)
                    draw_heatmaps(hm_kpts)

                # 画出关键点
                if show_kpts:
                    # batch_xywh = cs_from_region_map(batch_region_maps=region, k=1, thr=0.1)
                    _, _, batch_xywh = evaluate_ap(region, bbox, k=10, conf_thr=0.1)
                    # print(f"{batch_xywh=}")
                    batch_xywh = batch_xywh[0]
                    if batch_xywh is None:
                        print("没有找到目标框")
                        batch_xywh = [[4, 4, 2, 2, 0]]

                    batch_kpts, _ = get_coordinates_from_heatmap(hm_kpts)
                    # batch_kpts = self.get_coordinates(hm_kpts[:, 4:])
                    # print(f"{batch_kpts.shape=}")

                    heatmaps_size = hm_kpts.shape[-1]
                    batch_kpts[..., :2] = batch_kpts[..., :2] * new_size / heatmaps_size  # scale to original size


                    for image, kpts, xywh in zip(img, batch_kpts, batch_xywh):    
                        image = image.permute(1, 2, 0).detach().numpy()
                        m = np.array([0.485, 0.456, 0.406])
                        s = np.array([0.229, 0.224, 0.225])
                        image = image * s + m
                        image *= 255
                        image = image.astype(np.uint8) 
                        
                        kpts = kpts.squeeze(dim=0).detach().cpu().numpy()
                        print(f"{image.shape=}")
                        print(f"{kpts.shape=}")
                        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
                        image_drawn = draw_point(img=image.copy(), keypoints=kpts)
                        image_drawn = self.draw_region_bbox(image_drawn, xywh)

                    end = cv2.getTickCount()  # 计时器终点
                    fps = round(cv2.getTickFrequency() / (end - start))
                    text = str(fps) + "fps"
                    img = draw_text(image_drawn, text, (15, 15, 20, 20))
                    # img = img[:,:,::-1]   # BGR to RGB
                    imgs = torch.stack([torch.tensor(image, dtype=torch.uint8),
                                        torch.tensor(image_drawn, dtype=torch.uint8)], dim=0)
                    # img_grid = vutils.make_grid(imgs, normalize=True, scale_each=True, nrow=2)
                    self.writer.add_images('images', imgs, i, dataformats='NHWC')
                    
            pck = pck / (n_img+1)
            print(f"{n_img=}")
            print(f"{pck=}")
            self.writer.close()

    @staticmethod
    def draw_region_bbox(img, xywhc):
        cx, cy, w, h = xywhc[:4]
        x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
        x1 = int(max(0, x1))
        y1 = int(max(0, y1))
        x2 = int(min(x2, new_size))
        y2 = int(min(y2, new_size))
        img = draw_bbox(img, x1, y1, x2, y2)
        return img
        
    @staticmethod
    def get_coordinates(batch_kpts_hm):
        # (batch, n_joints, h, w)
        batch, n_joints, h, w = batch_kpts_hm.shape
        top_val, top_idx = torch.topk(batch_kpts_hm.reshape((batch, n_joints, -1)), k=1)

        batch_kpts = torch.zeros((batch, n_joints, 3))
        batch_kpts[..., 0] = (top_idx % w).reshape((batch, n_joints))  # x
        batch_kpts[..., 1] = (top_idx // w).reshape((batch, n_joints))  # y
        batch_kpts[..., 2] = top_val.reshape((batch, n_joints))  # c: score

        return batch_kpts

In [61]:
path = "./checkpoint/MSRB-D-DW-PELEE/1HG-ME-att-c256/2021-12-27/0.981_PCK_47epoch.pt"
t = TestPreds(checkpoint=path, is_cuda=False, ground_truth=False)
t.test(n_img=20, show_hms=False, show_kpts=True)

preparing data...
sample number of testing dataset:  13024
done!
loading state dict...
save_dict.keys()=dict_keys(['epoch', 'lr', 'loss', 'mPCK', 'ap', 'state_dict', 'optimizer', 'config'])
done! is_match=True
pck=0.959821430966258
image.shape=(256, 256, 3)
kpts.shape=(21, 2)
pck=1.9523809552192688
image.shape=(256, 256, 3)
kpts.shape=(21, 2)
pck=2.9434523843228817
image.shape=(256, 256, 3)
kpts.shape=(21, 2)
pck=3.900297624990344
image.shape=(256, 256, 3)
kpts.shape=(21, 2)
pck=4.858630960807204
image.shape=(256, 256, 3)
kpts.shape=(21, 2)
pck=5.848214294761419
image.shape=(256, 256, 3)
kpts.shape=(21, 2)
pck=6.776785724796355
image.shape=(256, 256, 3)
kpts.shape=(21, 2)
pck=7.775297629646957
image.shape=(256, 256, 3)
kpts.shape=(21, 2)
pck=8.751488107256591
image.shape=(256, 256, 3)
kpts.shape=(21, 2)
pck=9.741071441210806
image.shape=(256, 256, 3)
kpts.shape=(21, 2)
pck=10.66964287403971
image.shape=(256, 256, 3)
kpts.shape=(21, 2)
pck=11.607142876833677
image.shape=(256, 256, 3)
kp