# Example: predictions of the semi-supervised pose model for art

## Loading dependencies and defining auxiliary functions

In [3]:
!pip install huggingface_hub;
import numpy as np
import torch
from huggingface_hub import hf_hub_download, hf_hub_url

try:
    # cached_download is deprecated and missing from recent huggingface_hub
    # releases; keep the name importable where it still exists.
    from huggingface_hub import cached_download
except ImportError:
    cached_download = None

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [4]:


def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    The last dimension of ``x`` is split into four components; leading
    dimensions are kept. The result's last dimension is
    (x_c - w/2, y_c - h/2, x_c + w/2, y_c + h/2).
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h)
    return torch.stack(corners, dim=-1)


def box_xyxy_to_cxcywh(x):
    """Convert boxes from corner form (x0, y0, x1, y1) to center/size form.

    The last dimension of ``x`` is split into four components; leading
    dimensions are kept. The result's last dimension is
    ((x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0).
    """
    left, top, right, bottom = x.unbind(-1)
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    width = right - left
    height = bottom - top
    return torch.stack((center_x, center_y, width, height), dim=-1)


def box_xyxy_to_4points(x):
    """Expand corner-form boxes into their four corner points.

    The last dimension (x0, y0, x1, y1) becomes eight values laid out as the
    points (x0, y0), (x1, y1), (x0, y1), (x1, y0), in that order.
    """
    left, top, right, bottom = x.unbind(-1)
    corner_points = (
        left, top,
        right, bottom,
        left, bottom,
        right, top,
    )
    return torch.stack(corner_points, dim=-1)


def box_4points_to_xyxy(x):
    """Collapse four (x, y) points per box into an axis-aligned corner box.

    The last dimension of ``x`` holds eight values, read as four interleaved
    (x, y) pairs. The result's last dimension is
    (min x, min y, max x, max y) over those four points.
    """
    ax, ay, bx, by, cx, cy, dx, dy = x.unbind(-1)
    all_x = torch.stack((ax, bx, cx, dx), dim=-1)
    all_y = torch.stack((ay, by, cy, dy), dim=-1)

    x_min, _ = all_x.min(dim=-1)
    y_min, _ = all_y.min(dim=-1)
    x_max, _ = all_x.max(dim=-1)
    y_max, _ = all_y.max(dim=-1)
    return torch.stack((x_min, y_min, x_max, y_max), dim=-1)


def point_to_abs(points, size):
    """Scale point coordinates by the given size.

    ``points`` is flattened to rows of its last dimension. In every row the
    first value is multiplied by ``size[1]`` and the second by ``size[0]``;
    any further columns are passed through untouched. The result has the
    same shape as the input.
    """
    input_shape = points.shape
    flat = points.reshape(-1, input_shape[-1])
    coords = flat[:, :2]
    extras = flat[:, 2:]

    # first column is scaled by size[1], second column by size[0]
    scale = torch.as_tensor([size[1], size[0]], device=coords.device)
    return torch.cat([coords * scale, extras], dim=-1).reshape(input_shape)


def point_to_rel(points, size):
    """Divide point coordinates by the given size.

    ``points`` is flattened to rows of its last dimension. In every row the
    first value is divided by ``size[1]`` and the second by ``size[0]``;
    any further columns are passed through untouched. The result has the
    same shape as the input.
    """
    input_shape = points.shape
    flat = points.reshape(-1, input_shape[-1])
    coords = flat[:, :2]
    extras = flat[:, 2:]

    # first column is divided by size[1], second column by size[0]
    divisor = torch.as_tensor([size[1], size[0]], device=coords.device)
    return torch.cat([coords / divisor, extras], dim=-1).reshape(input_shape)


def points_transformation(points, transformation):
    """Apply a 3x3 homogeneous transformation to 2D points.

    Args:
        points: tensor whose last dimension starts with (x, y); any further
            columns are carried through unchanged.
        transformation: either a single ``[3, 3]`` matrix applied to every
            point, or a ``[N, 3, 3]`` batch with one matrix per flattened point.

    Returns:
        Tensor with the same shape as ``points`` holding the transformed
        (x, y) followed by the untouched extra columns. The first two rows of
        the matrix product are used as-is; there is no division by the third
        homogeneous component.
    """
    original_shape = points.shape

    # prepare points [N,3]
    points = points.reshape(-1, original_shape[-1])
    points, meta = points[:, :2], points[:, 2:]

    # Append the homogeneous 1. torch.ones keeps a proper [0, 1] column when
    # there are no points; building the column from a Python list collapses to
    # shape [0] in that case and leaves the points with only two columns.
    ones = torch.ones((points.shape[0], 1), device=points.device)
    points = torch.cat([points, ones], dim=1)
    points = points.unsqueeze(2)

    # prepare transformation [N,3,3]
    if len(transformation.shape) == 2:
        transformation = torch.unsqueeze(transformation, dim=0).expand(points.shape[0], 3, 3)

    transformed_points = transformation @ points
    transformed_points = torch.cat([torch.squeeze(transformed_points, 2)[:, :2], meta], dim=-1)
    return transformed_points.reshape(original_shape)


def boxes_to_abs(boxes, size):
    """Scale box coordinates by the given size.

    ``boxes`` is flattened to rows of its last dimension. The first four
    values of each row are multiplied by (size[1], size[0], size[1], size[0]);
    any further columns are passed through untouched. The result has the
    same shape as the input.
    """
    input_shape = boxes.shape
    flat = boxes.reshape(-1, input_shape[-1])
    coords = flat[:, :4]
    extras = flat[:, 4:]

    # columns 0 and 2 are scaled by size[1], columns 1 and 3 by size[0]
    scale = torch.as_tensor([size[1], size[0], size[1], size[0]], device=coords.device)
    return torch.cat([coords * scale, extras], dim=-1).reshape(input_shape)


def boxes_to_rel(boxes, size):
    """Divide box coordinates by the given size.

    ``boxes`` is flattened to rows of its last dimension. The first four
    values of each row are divided by (size[1], size[0], size[1], size[0]);
    any further columns are passed through untouched. The result has the
    same shape as the input.
    """
    input_shape = boxes.shape
    flat = boxes.reshape(-1, input_shape[-1])
    coords = flat[:, :4]
    extras = flat[:, 4:]

    # columns 0 and 2 are divided by size[1], columns 1 and 3 by size[0]
    divisor = torch.as_tensor([size[1], size[0], size[1], size[0]], device=coords.device)
    return torch.cat([coords / divisor, extras], dim=-1).reshape(input_shape)


def boxes_transformation(boxes, transformation):
    """Apply a homogeneous transformation to corner-form boxes.

    Each box (first four values of the last dimension, read as x0, y0, x1, y1)
    is expanded to its four corner points, every corner is transformed with
    ``points_transformation``, and the axis-aligned min/max box around the
    transformed corners is returned. Columns after the fourth are carried
    through unchanged; the output has the same shape as ``boxes``.
    """
    input_shape = boxes.shape

    # rows of [x0, y0, x1, y1, *extras]
    flat = boxes.reshape(-1, input_shape[-1])
    corners_xyxy = flat[:, :4]
    extras = flat[:, 4:]

    # one (x, y) row per corner, four consecutive rows per box
    corner_points = box_xyxy_to_4points(corners_xyxy).reshape(-1, 2)
    moved_points = points_transformation(corner_points, transformation).reshape(-1, 8)

    # axis-aligned box around the four transformed corners
    enclosing = box_4points_to_xyxy(moved_points)
    return torch.cat([enclosing, extras], dim=-1).reshape(input_shape)


def boxes_fit_size(boxes, size):
    """Clip corner-form boxes to the rectangle [0, w] x [0, h].

    Args:
        boxes: tensor whose elements can be viewed as (x, y) pairs, e.g. a
            last dimension of (x0, y0, x1, y1).
        size: indexable with ``size[0]`` read as the height and ``size[1]`` as
            the width. May be a tuple, list or tensor.

    Returns:
        Tensor of the same shape as ``boxes`` with every x limited to
        ``[0, w]`` and every y limited to ``[0, h]``.
    """
    h, w = size[0], size[1]

    original_shape = boxes.shape

    # Build the limit on the boxes' device: `size` may be a plain tuple/list
    # without a `.device` attribute, and the limit has to live on the same
    # device as the tensor it is compared against.
    max_size = torch.as_tensor([w, h], dtype=torch.float32, device=boxes.device)
    # view as (x, y) pairs so [w, h] broadcasts over both corners
    boxes = torch.min(boxes.reshape(-1, 2, 2), max_size)
    boxes = boxes.clamp(min=0)

    return boxes.reshape(original_shape)


def boxes_scale(boxes, scale, size=None):
    """Scale the width and height of corner-form boxes around their centers.

    Works on the last dimension, so both a single box of shape ``[4]`` and a
    batch of shape ``[..., 4]`` are handled.

    Args:
        boxes: tensor whose last dimension is (x0, y0, x1, y1).
        scale: factor multiplied onto width and height.
        size: optional size passed to ``boxes_fit_size`` to clip the result.

    Returns:
        The scaled boxes in corner form, clipped when ``size`` is given.
    """
    box_cxcywh = box_xyxy_to_cxcywh(boxes)
    # Slice the last dimension (not the first) so batched input keeps each
    # box's own center and size together.
    centers = box_cxcywh[..., :2]
    scaled_box_wh = box_cxcywh[..., 2:] * scale
    scaled_box = box_cxcywh_to_xyxy(torch.cat([centers, scaled_box_wh], dim=-1))
    if size is not None:
        scaled_box = boxes_fit_size(scaled_box, size)

    return scaled_box


def boxes_aspect_ratio(boxes, aspect_ratio, size=None):
    """Grow corner-form boxes around their centers to a width/height ratio.

    A box that is too wide gets its height increased to ``w / aspect_ratio``;
    a box that is too tall gets its width increased to ``h * aspect_ratio``.
    Works on the last dimension, so both a single box of shape ``[4]`` and a
    batch of shape ``[..., 4]`` are handled.

    Args:
        boxes: tensor whose last dimension is (x0, y0, x1, y1).
        aspect_ratio: target width divided by height.
        size: optional size passed to ``boxes_fit_size`` to clip the result.

    Returns:
        The adjusted boxes in corner form, clipped when ``size`` is given.
    """
    box_cxcywh = box_xyxy_to_cxcywh(boxes)
    w, h = box_cxcywh[..., 2], box_cxcywh[..., 3]

    # Element-wise selection replaces the Python if/elif so each box in a
    # batch is decided individually.
    target_w = aspect_ratio * h
    n_h = torch.where(w > target_w, w * 1.0 / aspect_ratio, h)
    n_w = torch.where(w < target_w, h * aspect_ratio, w)

    scaled_box = box_cxcywh_to_xyxy(torch.stack([box_cxcywh[..., 0], box_cxcywh[..., 1], n_w, n_h], dim=-1))
    if size is not None:
        scaled_box = boxes_fit_size(scaled_box, size)
    return scaled_box


## Person detection

In [5]:

def post_process_person_predictions(pred_logits, pred_boxes, targets, threshold=0.1):
    """Turn raw detector outputs into per-image boxes, labels and scores.

    Args:
        pred_logits: class logits indexed as (batch, box, class).
        pred_boxes: boxes indexed as (batch, box, 4) in (cx, cy, w, h) form.
            They are scaled by ``targets["size"]``, so presumably relative
            coordinates -- confirm against the model.
        targets: dict with per-image ``"size"`` and a 3x3 ``"transformation"``
            per image; the inverse transformation is applied to the boxes.
        threshold: minimum softmax probability for a (box, class) pair to be
            kept.

    Returns:
        Dict with ``"size"`` (taken from ``targets``) and the lists ``"boxes"``,
        ``"labels"`` and ``"scores"`` holding one tensor per batch element.
        Boxes are absolute (x0, y0, x1, y1) values after undoing the
        transformation. Images without detections get empty tensors.
    """
    predictions = {"boxes": [], "labels": [], "size": targets["size"], "scores": []}

    batch_size = pred_logits.shape[0]

    label_softmax = torch.softmax(pred_logits, dim=-1)
    # The last class column is excluded from the candidates (presumably the
    # "no object" class). Each row of boxes_pos is (batch, box, class).
    boxes_pos = (label_softmax > threshold)[..., :-1].nonzero()

    for b in range(batch_size):
        inv_transformation = torch.linalg.inv(targets["transformation"][b])
        weak_boxes_abs = boxes_to_abs(box_cxcywh_to_xyxy(pred_boxes[b]), size=targets["size"][b])
        boxes_origins_abs = boxes_transformation(weak_boxes_abs, inv_transformation)

        boxes_sample = boxes_pos[boxes_pos[:, 0] == b]
        box_index, box_cls = boxes_sample[:, 1], boxes_sample[:, 2]

        # Index all detections of this image at once; with no detections the
        # results are empty tensors of shape [0, 4], [0] and [0].
        predictions["boxes"].append(boxes_origins_abs[box_index])
        predictions["labels"].append(box_cls)
        predictions["scores"].append(label_softmax[b, box_index, box_cls])
    return predictions


In [6]:

REPO_ID = "springsteinm/iart-semi-pose"
FILENAME = "popart_semi_bbox_v1_trace.pt"

# hf_hub_download fetches the file into the local Hugging Face cache (or reuses
# it) and returns its path. It replaces cached_download(hf_hub_url(...)), which
# is deprecated and no longer shipped in current huggingface_hub releases.
person_model = torch.jit.load(hf_hub_download(repo_id=REPO_ID, filename=FILENAME))



Downloading:   0%|          | 0.00/334M [00:00<?, ?B/s]

In [7]:
import imageio
import matplotlib.pyplot as plt

# Path of the example painting, relative to the notebook's working directory.
example_image = "./benjamin-west_mrs-thomas-keyes-and-her-daughter.jpg"

try:
    image = imageio.imread(example_image)
except FileNotFoundError as err:
    # Make the most common failure actionable instead of a bare traceback.
    raise FileNotFoundError(
        f"Example image not found at {example_image!r}. Upload the image next to this "
        "notebook or set example_image to the path of an existing image."
    ) from err

fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title("Input image")
ax.set_axis_off();

FileNotFoundError: ignored

In [None]:
# Inference only: no autograd graph is needed for post-processing or cropping.
with torch.no_grad():
    prediction = person_model(torch.from_numpy(image))

# targets describes how the image was changed before it was passed to the model.
# Nothing was changed here, so both sizes are the image's (height, width) and the
# transformation is the 3x3 identity.
targets = {
    "size": [image.shape[0:2]],
    "origin_size": [image.shape[0:2]],
    "transformation": [torch.eye(3)],
}
final_person_prediciton = post_process_person_predictions(prediction[0], prediction[1], targets=targets)

In [None]:
cropped_images = []
for box in final_person_prediciton["boxes"][0]:
    # The post-processed boxes are already absolute (x0, y0, x1, y1) pixel
    # coordinates, so they are used directly; no further format conversion.
    x0, y0, x1, y1 = (int(v) for v in box.tolist())
    # Clamp the lower bounds so a slightly negative coordinate does not wrap
    # around to the end of the array.
    box_image = image[max(0, y0):y1, max(0, x0):x1, :]
    fig, axs = plt.subplots(1, 1)
    axs.imshow(box_image)
    cropped_images.append(box_image)