In [1]:
from dataset import PoseDataset
from utils import load_config, load_model
from train import validate
from utils import JointsMSELoss
import torch
from torch.utils.data import DataLoader
from utils import get_keypoints_from_heatmaps, compute_oks

In [35]:
config = load_config('./configs/config_w48_384x288.yaml')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_size = config["dataset"]["preprocess"]["input_size"]
batch_size = config["training"]["batch_size"]

run = '20240930_104432'
epoch = 1
config['model']['weights'] = f'runs/{run}/checkpoint_epoch_{epoch}/weights_epoch_{epoch}.pth'

model = load_model(config['model'])
model = model.to(device)

val_dataset = PoseDataset(config["dataset"], config["dataset"]['val'])

val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=8
    )

In [22]:
model.eval()
avg_oks = 0
sigmas = torch.ones(34) * 0.03
sigmas = sigmas.to(device)
with torch.no_grad():
    for images, targets, gt_keypoints, keypoint_visibility,bbox in val_loader:
        images, targets = images.to(device), targets.to(device)
        gt_keypoints = gt_keypoints.to(device)
        keypoint_visibility = keypoint_visibility.to(device)
        bbox = bbox.to(device)
        outputs = model(images)
        pred_keypoints = get_keypoints_from_heatmaps(
            outputs.detach(), input_size[::-1]
        )

        # area = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1])
        area = torch.ones(bbox.shape[0]) * images.shape[2] * images.shape[3]
        area = area.to(device)
        oks = compute_oks(pred_keypoints, gt_keypoints, keypoint_visibility, sigmas, area)
        avg_oks += oks

print(f'Average OKS: {avg_oks / len(val_loader)}')

Average OKS: 0.7557495832443237


w48-384x288: 0.8221208624383236
w32-256x192: 0.7493196738527176
w48_768x576: 0.8510695039269759
ViT-Pose-B-simple: 0.5001742104266552
ViT-Pose-B-classic: 0.7557495832443237

In [38]:
model.eval()
sigmas = torch.ones(17) * 0.03
sigmas = sigmas.to(device)

# Define OKS thresholds (similar to COCO: 0.50 to 0.95 with a step of 0.05)
oks_thresholds = torch.arange(0.5, 1.0, 0.05).to(device)

# Store precision at each threshold
precisions = torch.zeros(len(oks_thresholds)).to(device)

batch_count = 0

with torch.no_grad():
    for images, targets, gt_keypoints, keypoint_visibility, bbox in val_loader:
        images, targets = images.to(device), targets.to(device)
        gt_keypoints = gt_keypoints.to(device)
        keypoint_visibility = keypoint_visibility.to(device)
        bbox = bbox.to(device)
        
        # Model inference
        outputs = model(images)
        pred_keypoints = get_keypoints_from_heatmaps(
            outputs.detach(), input_size[::-1]
        )
        
        # Compute area (bbox can be used, or image size as in your original code)
        area = torch.ones(bbox.shape[0]) * images.shape[2] * images.shape[3]
        area = area.to(device)
        
        # Compute OKS values for the current batch
        oks = compute_oks(pred_keypoints, gt_keypoints, keypoint_visibility, sigmas, area)

        # For each OKS threshold, check if OKS is above the threshold (correct prediction)
        for i, threshold in enumerate(oks_thresholds):
            precisions[i] += (oks >= threshold).float().mean()
        
        batch_count += 1

precisions /= batch_count

# Compute mean Average Precision (mAP)
map_value = precisions.mean().item()

print(f'mAP: {map_value:.4f}')

mAP: 0.9000


34
w32-256x192: 0.5383
w48-384x288: 0.7000
w48_768x576: 0.7548
ViT-Pose-B-classic: 0.5894

17
w48-384x288: 0.7032