In [2]:
import os
# import base64
import numpy as np
import pandas as pd
# from pycocotools import _mask as coco_mask
# import zlib
import cv2
from ultralytics import YOLO
from ultralytics.nn.autobackend import AutoBackend
import torch
import torch.nn.functional as F
from glob import glob

In [27]:
models = [
    YOLO("/home/ilia_kiselev/Downloads/mix_8_cls.pt"),
    YOLO("/home/ilia_kiselev/Downloads/smooth_tal_light_002.pt"),

]
CONF=0.001

In [4]:
def encode_binary_mask(mask: np.ndarray) -> str:
    """Converts a binary mask into OID challenge encoding ascii text."""
    # check input mask --
    if mask.dtype != bool:
        raise ValueError(
            "encode_binary_mask expects a binary mask, received dtype == %s" %
             mask.dtype)

    mask = np.squeeze(mask)
    if len(mask.shape) != 2:
        raise ValueError(
            "encode_binary_mask expects a 2d mask, received shape == %s" %
            mask.shape)

    # convert input mask to expected COCO API input --
    mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
    mask_to_encode = mask_to_encode.astype(np.uint8)
    mask_to_encode = np.asfortranarray(mask_to_encode)

    # RLE encode mask --
    encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

    # compress and base64 encoding --
    binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
    base64_str = base64.b64encode(binary_str)
    return base64_str

In [5]:
fails = []
def return_default_value_if_fails(default_value):

    def decorator(func):
        def inner(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                fails.append((func, (args, kwargs), e))
                return default_value
        return inner

    return decorator

In [6]:
@return_default_value_if_fails(default_value='')
def collect_predictions(result):
    print(result.masks.data.shape)
#     rescaled_mask = F.interpolate(result.masks.data.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False)
#     rescaled_mask = rescaled_mask.squeeze(0)
    masks = result.masks.data.cpu().numpy()
    classes = result.boxes.cls.cpu().tolist()
    confs = result.boxes.conf.cpu().numpy()
    prediction_string = []

    for pred, conf, mask in zip(classes, confs, masks):
        if pred > 0:
            continue
        mask = mask.astype(bool)
        encoded_mask = 'encoded' #encode_binary_mask(mask).decode('utf-8')
        prediction_string.append(f"{int(pred)} {conf} {encoded_mask}")    
    return " ".join(prediction_string)

In [121]:
# @return_default_value_if_fails(default_value='')
def collect_predictions_nms(boxes, masks):
#     rescaled_mask = F.interpolate(result.masks.data.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False)
#     rescaled_mask = rescaled_mask.squeeze(0)
    masks = masks.cpu().numpy()
    classes = boxes[:, 5].cpu().tolist()
    confs = boxes[:, 4].cpu().numpy()
    prediction_string = []

    for pred, conf, mask in zip(classes, confs, masks):
        if pred > 0:
            continue
        mask = mask.astype(bool)
        encoded_mask = 'encoded' #encode_binary_mask(mask).decode('utf-8')
        prediction_string.append(f"{int(pred)} {conf} {encoded_mask}")    
    return " ".join(prediction_string)

In [129]:
glob('/home/ilia_kiselev/Figure_1.png')

['/home/ilia_kiselev/Figure_1.png']

In [143]:
from PIL import Image
import io


ids = []
heights = []
widths = []
prediction_strings = []

  
for image in glob('/media/ilia_kiselev/Datasets/hubmap-hacking-the-human-vasculature/test/*'):
    with Image.open(image) as im:
        in_mem_file = io.BytesIO()
        im.save(in_mem_file, format='JPEG')
        in_mem_file.seek(0)
        im_jpg = Image.open(in_mem_file)
    res = []
    for m_idx, model in enumerate(models):
        for result in model(im_jpg, conf=0.001, iou=0.7, max_det=300, retina_masks=True):
            if result.masks:
                masks = result.masks.data
                boxes = result.boxes.data
                classes = one_hot(boxes[:, 5].long(), num_classes=8) * boxes[:, 4].unsqueeze(1).to(boxes.device)
                index = torch.arange(len(boxes)).to(boxes.device).unsqueeze(1)
                m = torch.ones_like(index).to(boxes.device) * m_idx
                res.append((
                    torch.cat(
                        (boxes[:, :4], classes, boxes[:, 4:6], index, m),
                        dim=1
                    ),
                    masks
                ))
            break
    if res:
        all_boxes = torch.cat([r[0] for r in res])
        all_boxes = all_boxes.transpose(1,0).unsqueeze(0)
        nms_boxes = non_max_suppression(
            all_boxes,
            conf_thres=0.001,
            iou_thres=0.7,
            max_det=300,
            nc = 8
        )[0]
        unique_models = torch.unique(nms_boxes[:, 9])
        split_tensors = {m.item(): nms_boxes[nms_boxes[:, 9] == m] for m in unique_models}
        nms_masks = []
        nms_preds = []
        for m, preds in split_tensors.items():
            masks = res[int(m)][1]
            masks = masks[preds[:, 8].long()]
            nms_masks.append(masks)
            nms_preds.append(preds)
        nms_masks = torch.cat(nms_masks)
        nms_boxes = torch.cat(nms_preds)[:, :6]
        prediction_strings.append(collect_predictions_nms(nms_boxes, nms_masks))
    else:
        prediction_strings.append('')
    ids.append(image.split("/")[-1].split(".")[0])
    h, w = result.orig_shape
    heights.append(h)
    widths.append(w)


0: 1280x1280 166 blood_vessels, 2 glomeruluss, 3 FTEs, 317.5ms
Speed: 6.5ms preprocess, 317.5ms inference, 167.3ms postprocess per image at shape (1, 3, 1280, 1280)

0: 1280x1280 164 blood_vessels, 4 glomeruluss, 324.9ms
Speed: 6.6ms preprocess, 324.9ms inference, 221.7ms postprocess per image at shape (1, 3, 1280, 1280)


In [144]:
nms_boxes

tensor([[ 1.1264e+02,  1.0245e+02,  5.3438e+02,  3.8535e+02,  7.4424e-01,  0.0000e+00],
        [ 9.5412e+01,  1.9949e+02,  3.9584e+02,  6.8522e+02,  6.4378e-01,  0.0000e+00],
        [ 5.3496e+00,  1.0437e+02,  4.0264e+02,  5.3261e+02,  4.2012e-01,  0.0000e+00],
        [ 2.1909e+02,  8.6081e+01,  7.2534e+02,  3.4592e+02,  3.3954e-01,  0.0000e+00],
        [ 1.2531e+02,  2.2616e+02,  4.9748e+02,  7.3807e+02,  2.6287e-01,  0.0000e+00],
        [ 1.0894e+02, -2.5101e+01,  5.5702e+02,  2.5286e+01,  2.2370e-01,  2.0000e+00],
        [-3.1366e+01,  1.1759e+01,  3.7204e+02,  4.5637e+02,  5.1673e-02,  0.0000e+00],
        [ 1.7831e+02,  1.0116e+02,  6.3063e+02,  4.1391e+02,  4.5357e-02,  0.0000e+00],
        [-7.6191e+01, -5.9073e+01,  1.9967e+02,  5.9517e+01,  3.1084e-02,  0.0000e+00],
        [ 1.3333e+01,  1.8701e+01,  1.4062e+02,  1.3034e+02,  2.0997e-02,  0.0000e+00],
        [-1.5578e+02,  1.2646e+01,  2.3369e+02,  4.9348e+02,  1.3970e-02,  0.0000e+00],
        [ 2.1869e+02,  2.0238e+0

In [106]:
preds[:, 8].long().shape

torch.Size([61])

In [80]:
nc = 8
prediction=all_boxes
conf_thres = 0.001
bs = prediction.shape[0]  # batch size
nc = nc or (prediction.shape[1] - 4)  # number of classes
nm = prediction.shape[1] - nc - 4
mi = 4 + nc  # mask start index
xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

In [85]:
prediction[:, 4:mi].shape

torch.Size([16, 8])

In [76]:
all_boxes.transpose(1,0).shape

torch.Size([16, 339])

In [42]:
from torch.nn.functional import one_hot



In [48]:
all_boxes[:, 5].long()

tensor([0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [50]:
one_hot(all_boxes[:, 5].long(), num_classes=8) * all_boxes[:, 4].unsqueeze(1)

tensor([[0.7442, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.6438, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.4201, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0010, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0010, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0010, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [39]:
res[1][0]

tensor([[323.3514, 245.5656, 419.8745,  ...,   0.0000,   0.0000,   1.0000],
        [ 62.2975,  26.9251, 125.6931,  ...,   0.0000,   1.0000,   1.0000],
        [420.1374, 474.7765, 474.3341,  ...,   0.0000,   2.0000,   1.0000],
        ...,
        [ 75.3165, 358.2346, 216.8940,  ...,   0.0000, 165.0000,   1.0000],
        [135.2725,   1.2211, 383.0518,  ...,   1.0000, 166.0000,   1.0000],
        [342.9461, 429.6335, 466.5449,  ...,   0.0000, 167.0000,   1.0000]])

In [14]:
masks = result.masks.data
classes = result.boxes.cls
confs = result.boxes.conf
boxes = result.boxes.data

In [None]:
torch.cat((boxes,torch.arange(len(boxes)).to(boxes.device).unsqueeze(1)), dim=1) 

In [22]:
from ultralytics.yolo.utils.ops import non_max_suppression

In [26]:
non_max_suppression(
    torch.cat((boxes,torch.arange(len(boxes)).to(boxes.device).unsqueeze(1)), dim=1),
    conf_thres=0.001,
    iou_thres=0.7,
    max_det=300
)

171

In [57]:
all_boxes.unsqueeze(0)

tensor([[[323.5099, 243.9022, 421.7449,  ...,   0.0000,   0.0000,   0.0000],
         [245.6263, 442.3592, 300.4283,  ...,   0.0000,   1.0000,   0.0000],
         [203.9942, 318.4895, 397.2891,  ...,   0.0000,   2.0000,   0.0000],
         ...,
         [ 75.3165, 358.2346, 216.8940,  ...,   0.0000, 165.0000,   1.0000],
         [135.2725,   1.2211, 383.0518,  ...,   1.0000, 166.0000,   1.0000],
         [342.9461, 429.6335, 466.5449,  ...,   0.0000, 167.0000,   1.0000]]])