In [2]:
#clicker 
from copy import deepcopy

import cv2
import numpy as np


class Clicker:
    """Stores a sequence of clicks and can simulate the next corrective click.

    When a ground-truth mask is supplied, the clicker remembers which pixels
    were already clicked and can derive a new click from a predicted mask (see
    ``make_next_click``). Without a ground-truth mask it only acts as an
    ordered container for externally supplied clicks.
    """

    def __init__(
        self,
        gt_mask=None,
        init_clicks=None,
        ignore_label=-1,
        click_indx_offset=0,
    ):
        """Create a clicker.

        Args:
            gt_mask: optional label array; pixels equal to 1 form the object.
            init_clicks: optional iterable of click objects added immediately.
            ignore_label: value in ``gt_mask`` marking pixels that are excluded
                from the error regions used to place simulated clicks.
            click_indx_offset: value added to every index assigned by
                ``add_click``.
        """
        self.click_indx_offset = click_indx_offset
        if gt_mask is not None:
            self.__gt_mask = gt_mask == 1
            self.not_ignore_mask = gt_mask != ignore_label
        else:
            self.__gt_mask = None
            # Keep the attribute defined in both modes so that inspecting it
            # never raises AttributeError.
            self.not_ignore_mask = None

        self.reset_clicks()

        if init_clicks is not None:
            for click in init_clicks:
                self.add_click(click)

    def make_next_click(self, pred_mask):
        """Simulate one click from ``pred_mask`` and append it to the list.

        Requires that the clicker was created with a ground-truth mask.
        """
        assert self.__gt_mask is not None
        click = self._get_next_click(pred_mask)
        self.add_click(click)

    def get_clicks(self, clicks_limit=None):
        """Return the first ``clicks_limit`` clicks (all of them when None)."""
        return self.clicks_list[:clicks_limit]

    def _get_next_click(self, pred_mask, padding=True):
        """Place a click at the deepest point of the largest error region.

        Args:
            pred_mask: predicted foreground mask with the same shape as the
                ground truth. Integer arrays are interpreted as "non-zero
                means foreground".
            padding: when True, the error masks are zero-padded by one pixel
                before the distance transform so that the image border counts
                as a region boundary.

        Returns:
            A ``Click`` that is positive when the false-negative region is
            deeper than the false-positive one, negative otherwise.
        """
        # `~` is a bitwise NOT on integer arrays (~1 == -2), which would
        # silently corrupt the masks below, so integer predictions are first
        # converted to a real boolean mask.
        if isinstance(pred_mask, np.ndarray) and pred_mask.dtype.kind in "iu":
            pred_mask = pred_mask != 0

        fn_mask = self.__gt_mask & ~pred_mask & self.not_ignore_mask
        fp_mask = ~self.__gt_mask & pred_mask & self.not_ignore_mask

        if padding:
            fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
            fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")

        fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
        fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)

        if padding:
            fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
            fp_mask_dt = fp_mask_dt[1:-1, 1:-1]

        # Pixels that already received a click can never be chosen again.
        fn_mask_dt = fn_mask_dt * self.not_clicked_map
        fp_mask_dt = fp_mask_dt * self.not_clicked_map

        fn_max_dist = np.max(fn_mask_dt)
        fp_max_dist = np.max(fp_mask_dt)

        is_positive = fn_max_dist > fp_max_dist
        if is_positive:
            coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist)  # coords is [y, x]
        else:
            coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist)  # coords is [y, x]

        return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))

    def add_click(self, click):
        """Append ``click``, assign its index and mark its pixel as clicked.

        The index is ``click_indx_offset`` plus the number of clicks stored
        before this one; ``click.indx`` is overwritten in place.
        """
        coords = click.coords

        click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks
        if click.is_positive:
            self.num_pos_clicks += 1
        else:
            self.num_neg_clicks += 1

        self.clicks_list.append(click)
        if self.__gt_mask is not None:
            self.not_clicked_map[coords[0], coords[1]] = False

    def _remove_last_click(self):
        """Undo the most recent ``add_click`` call."""
        click = self.clicks_list.pop()
        coords = click.coords

        if click.is_positive:
            self.num_pos_clicks -= 1
        else:
            self.num_neg_clicks -= 1

        if self.__gt_mask is not None:
            self.not_clicked_map[coords[0], coords[1]] = True

    def reset_clicks(self):
        """Forget all clicks and, in ground-truth mode, reset the clicked map."""
        if self.__gt_mask is not None:
            self.not_clicked_map = np.ones_like(self.__gt_mask, dtype=bool)

        self.num_pos_clicks = 0
        self.num_neg_clicks = 0

        self.clicks_list = []

    def get_state(self):
        """Return a deep copy of the click list, suitable for ``set_state``."""
        return deepcopy(self.clicks_list)

    def set_state(self, state):
        """Replace the current clicks with the ones in ``state``."""
        self.reset_clicks()
        for click in state:
            self.add_click(click)

    def __len__(self):
        return len(self.clicks_list)


class Click:
    """A single user click: its polarity, its position and its ordinal index."""

    def __init__(self, is_positive, coords, indx=None):
        # `coords` is stored as given; `indx` stays None until someone assigns it.
        self.is_positive, self.coords, self.indx = is_positive, coords, indx

    @property
    def coords_and_indx(self):
        """The coordinates followed by the click index, as one flat tuple."""
        return tuple(self.coords) + (self.indx,)

    def copy(self, **kwargs):
        """Return a deep copy with the given attributes overridden."""
        clone = deepcopy(self)
        for attr_name in kwargs:
            setattr(clone, attr_name, kwargs[attr_name])
        return clone


In [10]:
#datasets
import json
import pickle
import random
from copy import deepcopy
from pathlib import Path

import cv2
import numpy as np
import torch
from torchvision import transforms


class ISDataset(torch.utils.data.dataset.Dataset):
    """Base dataset producing (image, simulated clicks, target mask) samples.

    Subclasses fill ``self.dataset_samples`` and implement ``get_sample``.
    """

    def __init__(
        self,
        augmentator=None,
        points_sampler=None,
        min_object_area=1000,
        epoch_len=-1,
    ):
        """Create the dataset.

        Args:
            augmentator: optional callable passed to the sample's ``augment``
                method; when None, samples are used unchanged.
            points_sampler: object that selects the target object and samples
                clicks. Defaults to a fresh ``MultiPointSampler`` with
                ``max_num_points=24`` for every dataset instance.
            min_object_area: objects with a smaller area are dropped from a
                sample before sampling.
            epoch_len: when positive, the reported dataset length; items are
                then drawn at random regardless of the requested index.
        """
        super().__init__()
        # The default sampler is created here rather than in the signature: a
        # signature default would be built once when this cell is executed
        # (before the sampler cell further down has run) and that single
        # stateful instance would be shared by every dataset.
        if points_sampler is None:
            points_sampler = MultiPointSampler(max_num_points=24)

        self.epoch_len = epoch_len
        self.augmentator = augmentator
        self.min_object_area = min_object_area
        self.points_sampler = points_sampler
        self.to_tensor = transforms.ToTensor()

        self.dataset_samples = None

    def __getitem__(self, index):
        """Build one training item.

        Returns:
            dict with ``"images"`` (tensor from ``ToTensor``), ``"points"``
            (float32 array of sampled clicks) and ``"instances"`` (the mask
            selected by the points sampler).
        """
        if self.epoch_len > 0:
            # With a fixed epoch length the requested index is ignored and a
            # uniformly random sample is used instead.
            index = random.randrange(0, len(self.dataset_samples))

        sample = self.get_sample(index)
        sample = self.augment_sample(sample)
        sample.remove_small_objects(self.min_object_area)

        self.points_sampler.sample_object(sample)
        points = np.array(self.points_sampler.sample_points())
        mask = self.points_sampler.selected_mask

        output = {
            "images": self.to_tensor(sample.image),
            "points": points.astype(np.float32),
            "instances": mask,
        }

        return output

    # Annotations are written as strings so that defining this class does not
    # require `DSample` to exist yet.
    def augment_sample(self, sample) -> "DSample":
        """Apply the configured augmentator in place (no-op when it is None)."""
        if self.augmentator is None:
            return sample
        sample.augment(self.augmentator)
        return sample

    def get_sample(self, index) -> "DSample":
        """Load the raw sample at ``index``; must be provided by subclasses."""
        raise NotImplementedError

    def __len__(self):
        if self.epoch_len > 0:
            return self.epoch_len
        return self.get_samples_number()

    def get_samples_number(self):
        """Number of distinct samples available in the dataset."""
        return len(self.dataset_samples)


class TestDataset(ISDataset):
    """Dataset of image files paired with single-object binary mask files.

    Images and masks live in two directories and are matched by file name
    without extension.
    """

    def __init__(self, images=None, masks=None, **kwargs):
        """Index the files found in the ``images`` and ``masks`` directories.

        Remaining keyword arguments are forwarded to ``ISDataset``.
        """
        super().__init__(**kwargs)

        self._images_path = Path(images)
        self._insts_path = Path(masks)

        self.dataset_samples = [x.name for x in sorted(self._images_path.glob("*.*"))]
        self._masks_paths = {x.stem: x for x in self._insts_path.glob("*.*")}

    # The annotation is a string so that this class can be defined before the
    # cell that declares `DSample` has been executed.
    def get_sample(self, index) -> "DSample":
        """Load image ``index`` and its mask as a one-object ``DSample``.

        Raises:
            KeyError: when no mask file matches the image name.
            FileNotFoundError: when OpenCV cannot read the image or the mask.
        """
        image_name = self.dataset_samples[index]
        image_path = str(self._images_path / image_name)

        # Masks are indexed by `Path.stem`, so the same rule is used for the
        # lookup ("a.b.jpg" -> "a.b"). The text before the first dot is kept
        # as a fallback for layouts that relied on the previous lookup rule.
        mask_key = Path(image_name).stem
        if mask_key not in self._masks_paths:
            mask_key = image_name.split(".")[0]
        mask_path = str(self._masks_paths[mask_key])

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Cannot read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        raw_mask = cv2.imread(mask_path)
        if raw_mask is None:
            raise FileNotFoundError(f"Cannot read mask: {mask_path}")
        # Any non-zero value in any channel counts as the object.
        instances_mask = np.max(raw_mask.astype(np.int32), axis=2)
        instances_mask[instances_mask > 0] = 1

        return DSample(
            image,
            instances_mask,
            objects_ids=[1],
            sample_id=index,
            imname=image_name,
        )


class CocoLvisDataset(ISDataset):
    """Dataset reading images, packed mask layers and a pickled annotation index.

    Expected layout under ``dataset_path / split``: an ``images`` directory
    with ``<image_id>.jpg`` files, a ``masks`` directory with
    ``<image_id>.pickle`` files and the annotation file ``anno_file``.
    """

    def __init__(
        self,
        dataset_path,
        split="train",
        stuff_prob=0.0,
        allow_list_name=None,
        anno_file="hannotation.pickle",
        **kwargs,
    ):
        """Load the annotation index and optionally filter it.

        Args:
            dataset_path: root directory of the dataset.
            split: sub-directory holding the split to use.
            stuff_prob: probability of keeping the masks that come after the
                first ``num_instance_masks`` entries as selectable objects.
            allow_list_name: optional JSON file (inside the split directory)
                listing the image ids to keep.
            anno_file: name of the pickled annotation file.
            **kwargs: forwarded to ``ISDataset``.
        """
        super().__init__(**kwargs)
        dataset_path = Path(dataset_path)
        self._split_path = dataset_path / split
        self.split = split
        self._images_path = self._split_path / "images"
        self._masks_path = self._split_path / "masks"
        self.stuff_prob = stuff_prob

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this dataset at annotation files from a trusted source.
        with open(self._split_path / anno_file, "rb") as f:
            self.dataset_samples = sorted(pickle.load(f).items())

        if allow_list_name is not None:
            allow_list_path = self._split_path / allow_list_name
            with open(allow_list_path, "r") as f:
                allow_images_ids = json.load(f)
            allow_images_ids = set(allow_images_ids)

            self.dataset_samples = [
                sample
                for sample in self.dataset_samples
                if sample[0] in allow_images_ids
            ]

    # The annotation is a string so that this class can be defined before the
    # cell that declares `DSample` has been executed.
    def get_sample(self, index) -> "DSample":
        """Load sample ``index`` together with its object hierarchy.

        Raises:
            FileNotFoundError: when OpenCV cannot read the image file.
        """
        image_id, sample = self.dataset_samples[index]
        image_path = self._images_path / f"{image_id}.jpg"

        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Cannot read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        packed_masks_path = self._masks_path / f"{image_id}.pickle"
        # NOTE(security): same caveat as above, mask files must be trusted.
        with open(packed_masks_path, "rb") as f:
            encoded_layers, objs_mapping = pickle.load(f)
        layers = [cv2.imdecode(x, cv2.IMREAD_UNCHANGED) for x in encoded_layers]
        layers = np.stack(layers, axis=2)

        # Work on a copy so the cached annotation is never modified.
        instances_info = deepcopy(sample["hierarchy"])
        for inst_id, inst_info in list(instances_info.items()):
            if inst_info is None:
                # Entries without hierarchy data become stand-alone root nodes.
                inst_info = {"children": [], "parent": None, "node_level": 0}
                instances_info[inst_id] = inst_info
            inst_info["mapping"] = objs_mapping[inst_id]

        if self.stuff_prob > 0 and random.random() < self.stuff_prob:
            # Keep the extra masks and expose them as additional root objects.
            for inst_id in range(sample["num_instance_masks"], len(objs_mapping)):
                instances_info[inst_id] = {
                    "mapping": objs_mapping[inst_id],
                    "parent": None,
                    "children": [],
                }
        else:
            # Otherwise erase the extra masks from their layers.
            for inst_id in range(sample["num_instance_masks"], len(objs_mapping)):
                layer_indx, mask_id = objs_mapping[inst_id]
                layers[:, :, layer_indx][layers[:, :, layer_indx] == mask_id] = 0

        return DSample(image, layers, objects=instances_info)


In [6]:
# misc
from functools import lru_cache

import cv2
import numpy as np
import torch


def get_dims_with_exclusion(dim, exclude=None):
    """Return ``[0, ..., dim - 1]`` with the ``exclude`` entry removed.

    Raises ``ValueError`` when ``exclude`` is given but is not one of the
    listed dimensions.
    """
    remaining = [*range(dim)]
    if exclude is None:
        return remaining

    remaining.remove(exclude)
    return remaining


def save_checkpoint(net, checkpoints_path, epoch=None):
    """Save ``net.state_dict()`` into the ``checkpoints_path`` directory.

    Args:
        net: module whose ``state_dict()`` is stored under ``"state_dict"``.
        checkpoints_path: target directory (``str`` or ``Path``); it is
            created together with missing parents.
        epoch: when given, the file is named ``"{epoch:03d}.pt"``; otherwise
            ``"last_checkpoint.pt"`` is written.
    """
    # Imported locally so the function does not depend on another cell's imports.
    from pathlib import Path

    # Accept plain strings as well as Path objects.
    checkpoints_path = Path(checkpoints_path)
    checkpoint_name = "last_checkpoint.pt" if epoch is None else f"{epoch:03d}.pt"

    # exist_ok avoids the check-then-create race of a separate exists() test.
    checkpoints_path.mkdir(parents=True, exist_ok=True)

    checkpoint_path = checkpoints_path / checkpoint_name
    print(f"Save checkpoint to {str(checkpoint_path)}")

    torch.save({"state_dict": net.state_dict()}, str(checkpoint_path))


def get_bbox_from_mask(mask):
    """Return the inclusive bounds ``(rmin, rmax, cmin, cmax)`` of the non-zero area.

    An all-zero mask has no bounds and results in an ``IndexError``.
    """
    occupied_rows = np.nonzero(np.any(mask, axis=1))[0]
    occupied_cols = np.nonzero(np.any(mask, axis=0))[0]

    # Fancy-index the first and last occupied position along each axis.
    first_and_last = [0, -1]
    rmin, rmax = occupied_rows[first_and_last]
    cmin, cmax = occupied_cols[first_and_last]
    return rmin, rmax, cmin, cmax


def expand_bbox(bbox, expand_ratio, min_crop_size=None):
    """Scale a ``(rmin, rmax, cmin, cmax)`` box around its centre.

    The box height and width (counted inclusively) are multiplied by
    ``expand_ratio`` and, when ``min_crop_size`` is given, raised to at least
    that size. The result is rounded to integers and is not clipped.
    """
    top, bottom, left, right = bbox
    center_r = 0.5 * (top + bottom)
    center_c = 0.5 * (left + right)

    new_height = expand_ratio * (bottom - top + 1)
    new_width = expand_ratio * (right - left + 1)
    if min_crop_size is not None:
        new_height, new_width = (
            max(new_height, min_crop_size),
            max(new_width, min_crop_size),
        )

    half_height = 0.5 * new_height
    half_width = 0.5 * new_width
    return (
        int(round(center_r - half_height)),
        int(round(center_r + half_height)),
        int(round(center_c - half_width)),
        int(round(center_c + half_width)),
    )


def clamp_bbox(bbox, rmin, rmax, cmin, cmax):
    """Restrict ``bbox`` (rmin, rmax, cmin, cmax order) to the given limits."""
    top = max(rmin, bbox[0])
    bottom = min(rmax, bbox[1])
    left = max(cmin, bbox[2])
    right = min(cmax, bbox[3])
    return top, bottom, left, right


def get_bbox_iou(b1, b2):
    """Product of the per-axis segment IoUs of two (rmin, rmax, cmin, cmax) boxes."""
    row_span, col_span = slice(0, 2), slice(2, 4)
    vertical_iou = get_segments_iou(b1[row_span], b2[row_span])
    horizontal_iou = get_segments_iou(b1[col_span], b2[col_span])
    return vertical_iou * horizontal_iou


def get_segments_iou(s1, s2):
    """Ratio of overlap to covering-span length of two inclusive integer segments.

    Lengths are counted inclusively (``end - start + 1``); the denominator is
    the span from the smallest start to the largest end and is kept at no
    less than ``1e-6``.
    """
    start_1, end_1 = s1
    start_2, end_2 = s2

    overlap_len = min(end_1, end_2) - max(start_1, start_2) + 1
    covering_len = max(end_1, end_2) - min(start_1, start_2) + 1
    return max(0, overlap_len) / max(1e-6, covering_len)


def get_labels_with_sizes(x):
    """Return the non-zero labels present in ``x`` and their pixel counts.

    ``x`` must hold non-negative integers (a requirement of ``np.bincount``).
    Both results are plain Python lists, ordered by increasing label.
    """
    pixel_counts = np.bincount(x.flatten())
    present = np.nonzero(pixel_counts)[0].tolist()
    # Label 0 is the background and is never reported.
    foreground = [label for label in present if label != 0]
    return foreground, pixel_counts[foreground].tolist()


@lru_cache(maxsize=16)
def get_palette(num_cls):
    """Build a ``(num_cls, 3)`` int32 colour table, one row per class index.

    The bits of the class index are dealt out in groups of three to the three
    channels, starting at each channel's most significant bit. Results are
    cached, so callers share the returned array.
    """
    flat_palette = np.zeros(3 * num_cls, dtype=np.int32)

    for cls_index in range(0, num_cls):
        remaining = cls_index
        bit_position = 7

        while remaining > 0:
            base = cls_index * 3
            for channel in range(3):
                flat_palette[base + channel] |= (
                    (remaining >> channel) & 1
                ) << bit_position
            bit_position -= 1
            remaining >>= 3

    return flat_palette.reshape((-1, 3))


def draw_probmap(x):
    """Render ``x`` with OpenCV's HOT colour map after scaling it by 255 to uint8."""
    scaled = (x * 255).astype(np.uint8)
    return cv2.applyColorMap(scaled, cv2.COLORMAP_HOT)


def draw_points(image, points, color, radius=3):
    """Draw filled circles for ``points`` on a copy of ``image``.

    Each point is indexed as ``(row, col)`` or ``(row, col, index)``. Points
    whose first coordinate is negative are skipped. Three-element points get a
    radius chosen from their index (8, 6 or 4 for 0, 1, 2; otherwise 2), other
    points use ``radius``.
    """
    radius_by_index = {0: 8, 1: 6, 2: 4}
    canvas = image.copy()

    for point in points:
        if point[0] < 0:
            continue

        if len(point) != 3:
            point_radius = radius
        elif point[2] < 3:
            point_radius = radius_by_index[point[2]]
        else:
            point_radius = 2

        # OpenCV expects the centre as (x, y), i.e. (col, row).
        center = (int(point[1]), int(point[0]))
        canvas = cv2.circle(canvas, center, point_radius, color, -1)

    return canvas


def draw_with_blend_and_clicks(
    img,
    mask=None,
    alpha=0.6,
    clicks_list=None,
    pos_color=(0, 255, 0),
    neg_color=(255, 0, 0),
    radius=4,
):
    """Overlay a label mask and click markers on a copy of ``img``.

    Args:
        img: image array; it is never modified.
        mask: optional label map. Pixels with a positive label are tinted with
            that label's palette colour; all other pixels stay unchanged.
        alpha: weight of the palette colour inside the mask region.
        clicks_list: optional clicks; each needs ``coords`` and ``is_positive``.
        pos_color: colour used for positive clicks.
        neg_color: colour used for negative clicks.
        radius: marker radius handed to ``draw_points``.

    Returns:
        The rendered image (uint8 whenever a mask was blended in).
    """
    result = img.copy()

    if mask is not None:
        # Use a wide integer index: casting to uint8 would wrap label ids
        # above 255 onto the wrong palette rows. Non-positive labels (for
        # example an ignore value of -1) are mapped to the black row 0.
        label_index = np.where(mask > 0, mask, 0).astype(np.int64)
        palette = get_palette(int(label_index.max()) + 1)
        rgb_mask = palette[label_index]

        mask_region = (mask > 0).astype(np.uint8)
        # Outside the mask the image is kept; inside it is dimmed by
        # (1 - alpha) and the palette colour is added with weight alpha.
        result = (
            result * (1 - mask_region[:, :, np.newaxis])
            + (1 - alpha) * mask_region[:, :, np.newaxis] * result
            + alpha * rgb_mask
        )
        result = result.astype(np.uint8)

    if clicks_list is not None and len(clicks_list) > 0:
        pos_points = [click.coords for click in clicks_list if click.is_positive]
        neg_points = [click.coords for click in clicks_list if not click.is_positive]

        result = draw_points(result, pos_points, pos_color, radius=radius)
        result = draw_points(result, neg_points, neg_color, radius=radius)

    return result


In [8]:
# sample 
from copy import deepcopy

import numpy as np


class DSample:
    """One image together with its label-encoded masks and object table.

    Masks are kept as an ``(H, W, layers)`` array. Every object is described
    by a ``(layer index, label value)`` mapping plus ``parent`` / ``children``
    links, stored in ``self._objects``.
    """

    def __init__(
        self,
        image,
        encoded_masks,
        objects=None,
        objects_ids=None,
        ignore_ids=None,
        sample_id=None,
        imname=None,
    ):
        """Create a sample.

        Args:
            image: the image; stored as given.
            encoded_masks: 2-D or 3-D label array; a 2-D array becomes a single layer.
            objects: ready-made object table (deep-copied); used only when
                ``objects_ids`` is None.
            objects_ids: list of label values (single-layer masks only) or of
                ``(layer, label)`` tuples, turned into parent-less objects.
            ignore_ids: labels or ``(layer, label)`` tuples of ignored regions;
                read only together with ``objects_ids``.
            sample_id: optional identifier kept on the sample.
            imname: optional image name kept on the sample.
        """
        self.image = image
        self.sample_id = sample_id
        self.imname = imname

        if encoded_masks.ndim == 2:
            encoded_masks = encoded_masks[..., np.newaxis]
        self._encoded_masks = encoded_masks
        self._ignored_regions = []

        if objects_ids is None:
            self._objects = deepcopy(objects)
        else:
            if not objects_ids or not isinstance(objects_ids[0], tuple):
                # Bare label values are only meaningful for one-layer masks.
                assert encoded_masks.shape[2] == 1
                objects_ids = [(0, obj_id) for obj_id in objects_ids]

            self._objects = {
                position: {"parent": None, "mapping": mapping, "children": []}
                for position, mapping in enumerate(objects_ids)
            }

            if ignore_ids:
                if isinstance(ignore_ids[0], tuple):
                    self._ignored_regions = ignore_ids
                else:
                    self._ignored_regions = [(0, region_id) for region_id in ignore_ids]

        self._augmented = False
        self._soft_mask_aug = None
        self._original_data = self.image, self._encoded_masks, deepcopy(self._objects)

    def augment(self, augmentator):
        """Apply ``augmentator`` to the original data and refresh object areas.

        ``augmentator`` is called with ``image=`` and ``mask=`` keywords and
        must return a mapping with ``"image"`` and ``"mask"`` entries.
        """
        self.reset_augmentation()
        augmented = augmentator(image=self.image, mask=self._encoded_masks)
        self.image = augmented["image"]
        self._encoded_masks = augmented["mask"]
        self._compute_objects_areas()
        # Objects that vanished during augmentation have area 0 and are dropped.
        self.remove_small_objects(min_area=1)
        self._augmented = True

    def reset_augmentation(self):
        """Restore the image, masks and object table stored at construction."""
        if self._augmented:
            original_image, original_masks, original_objects = self._original_data
            self.image = original_image
            self._encoded_masks = original_masks
            self._objects = deepcopy(original_objects)
            self._augmented = False
            self._soft_mask_aug = None

    def remove_small_objects(self, min_area):
        """Drop every object whose area is below ``min_area``."""
        # Areas are computed lazily: only when the first object lacks one.
        if self._objects and "area" not in next(iter(self._objects.values())):
            self._compute_objects_areas()

        for obj_id, obj_info in tuple(self._objects.items()):
            if obj_info["area"] < min_area:
                self._remove_object(obj_id)

    def get_object_mask(self, obj_id):
        """Return an int32 mask: 1 on the object, -1 on ignored regions, else 0."""
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)

        for ignored_layer, ignored_label in self._ignored_regions:
            obj_mask[self._encoded_masks[:, :, ignored_layer] == ignored_label] = -1

        return obj_mask

    def get_soft_object_mask(self, obj_id):
        """Return the object's float mask after ``_soft_mask_aug``, clipped to [0, 1].

        The mask is taken from the original (un-augmented) encoded masks.
        """
        assert self._soft_mask_aug is not None
        source_masks = self._original_data[1]
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        hard_mask = (source_masks[:, :, layer_indx] == mask_id).astype(np.float32)
        soft_mask = self._soft_mask_aug(image=hard_mask, mask=source_masks)["image"]
        return np.clip(soft_mask, 0, 1)

    def get_background_mask(self):
        """Boolean mask of the pixels whose maximum over all layers is 0."""
        return np.max(self._encoded_masks, axis=2) == 0

    @property
    def objects_ids(self):
        """Ids of all objects currently in the sample."""
        return list(self._objects.keys())

    @property
    def gt_mask(self):
        """Mask of the only object; the sample must contain exactly one."""
        assert len(self._objects) == 1
        return self.get_object_mask(self.objects_ids[0])

    @property
    def root_objects(self):
        """Ids of the objects that have no parent."""
        roots = []
        for obj_id, obj_info in self._objects.items():
            if obj_info["parent"] is None:
                roots.append(obj_id)
        return roots

    def _compute_objects_areas(self):
        """Store each object's pixel count under ``"area"``.

        Labels found in the masks that belong to no object and to no ignored
        region are erased (set to 0). Objects whose label is absent from the
        masks get an area of 0.
        """
        mapping_to_id = {
            info["mapping"]: obj_id for obj_id, info in self._objects.items()
        }
        ignored_keys = set(self._ignored_regions)

        for layer_indx in range(self._encoded_masks.shape[2]):
            labels, areas = get_labels_with_sizes(
                self._encoded_masks[:, :, layer_indx],
            )
            for label, area in zip(labels, areas):
                key = (layer_indx, label)
                if key in ignored_keys:
                    continue
                if key in mapping_to_id:
                    # Consume the entry so that leftovers mark missing objects.
                    self._objects[mapping_to_id.pop(key)]["area"] = area
                else:
                    layer = self._encoded_masks[:, :, layer_indx]
                    layer[layer == label] = 0
                    self._encoded_masks[:, :, layer_indx] = layer

        for missing_id in mapping_to_id.values():
            self._objects[missing_id]["area"] = 0

    def _remove_object(self, obj_id):
        """Delete an object and attach its children to its parent."""
        removed = self._objects[obj_id]
        new_parent = removed["parent"]
        orphans = removed["children"]

        for child_id in orphans:
            self._objects[child_id]["parent"] = new_parent

        if new_parent is not None:
            kept = [
                sibling
                for sibling in self._objects[new_parent]["children"]
                if sibling != obj_id
            ]
            self._objects[new_parent]["children"] = kept + orphans

        del self._objects[obj_id]

    def __len__(self):
        return len(self._objects)


In [9]:
# points sampler
import math
import random
from functools import lru_cache

import cv2
import numpy as np


class MultiPointSampler:
    """Selects a target object from a sample and simulates clicks for it.

    Usage is two-step: ``sample_object(sample)`` picks the target mask and
    prepares the candidate regions, then ``sample_points()`` draws positive
    and negative clicks from those regions. Results depend on the state of
    both the ``random`` module and ``np.random``.
    """

    def __init__(
        self,
        max_num_points,
        prob_gamma=0.7,
        expand_ratio=0.1,
        positive_erode_prob=0.9,
        positive_erode_iters=3,
        negative_bg_prob=0.1,
        negative_other_prob=0.4,
        negative_border_prob=0.5,
        merge_objects_prob=0.0,
        max_num_merged_objects=2,
        use_hierarchy=False,
        soft_targets=False,
        first_click_center=False,
        only_one_first_click=False,
        sfc_inner_k=1.7,
        sfc_full_inner_prob=0.0,
    ):
        """Configure the sampler.

        Args:
            max_num_points: length of each returned point list (positive and
                negative lists are padded or truncated to this size).
            prob_gamma: ratio passed to ``generate_probs`` for the
                distribution of the number of sampled points.
            expand_ratio: scales the dilation radius of the border region
                (see ``_get_border_mask``).
            positive_erode_prob: probability of eroding a mask before
                sampling from it.
            positive_erode_iters: number of erosion iterations.
            negative_bg_prob, negative_other_prob, negative_border_prob:
                probabilities of the "bg", "other" and "border" negative
                strategies; they must sum to 1.
            merge_objects_prob: probability of merging several root objects
                into one target.
            max_num_merged_objects: upper bound on merged objects; -1 means
                ``max_num_points``.
            use_hierarchy: when True, targets are picked by descending the
                object tree and parents/children provide extra regions.
            soft_targets: when True, the target is the sample's soft mask.
            first_click_center: when True, the first positive click is drawn
                from ``get_point_candidates``.
            only_one_first_click: when several objects produced points, keep
                a first click for the first of them only.
            sfc_inner_k: ``k`` passed to ``get_point_candidates``.
            sfc_full_inner_prob: ``full_prob`` passed to ``get_point_candidates``.
        """
        self._selected_mask = None
        self._selected_masks = None
        self.max_num_points = max_num_points
        self.expand_ratio = expand_ratio
        self.positive_erode_prob = positive_erode_prob
        self.positive_erode_iters = positive_erode_iters
        self.merge_objects_prob = merge_objects_prob
        self.use_hierarchy = use_hierarchy
        self.soft_targets = soft_targets
        self.first_click_center = first_click_center
        self.only_one_first_click = only_one_first_click
        self.sfc_inner_k = sfc_inner_k
        self.sfc_full_inner_prob = sfc_full_inner_prob

        if max_num_merged_objects == -1:
            max_num_merged_objects = max_num_points
        self.max_num_merged_objects = max_num_merged_objects

        self.neg_strategies = ["bg", "other", "border"]
        self.neg_strategies_prob = [
            negative_bg_prob,
            negative_other_prob,
            negative_border_prob,
        ]
        assert math.isclose(sum(self.neg_strategies_prob), 1.0)

        # Positive sampling draws 1..max_num_points points, negative sampling
        # draws 0..max_num_points, hence the one extra entry for negatives.
        self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma)
        self._neg_probs = generate_probs(max_num_points + 1, gamma=prob_gamma)
        self._neg_masks = None

    @property
    def selected_mask(self):
        """Target mask chosen by the last ``sample_object`` call."""
        assert self._selected_mask is not None
        return self._selected_mask

    @selected_mask.setter
    def selected_mask(self, mask):
        # Stored as float32 with an added leading axis.
        self._selected_mask = mask[np.newaxis, :].astype(np.float32)

    def sample_object(self, sample: DSample):
        """Choose the target for ``sample`` and prepare all sampling regions."""
        if len(sample) == 0:
            # No objects: the target is empty and every negative strategy
            # samples from the background.
            bg_mask = sample.get_background_mask()
            self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32)
            self._selected_masks = [[]]
            self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies}
            self._neg_masks["required"] = []
            return

        gt_mask, pos_masks, neg_masks = self._sample_mask(sample)
        binary_gt_mask = gt_mask > 0.5 if self.soft_targets else gt_mask > 0

        self.selected_mask = gt_mask
        self._selected_masks = pos_masks

        neg_mask_bg = ~binary_gt_mask
        neg_mask_border = self._get_border_mask(binary_gt_mask)
        if len(sample) <= len(self._selected_masks):
            # Every object of the sample was selected, so there is no "other"
            # object region; fall back to the complement of the target.
            neg_mask_other = neg_mask_bg
        else:
            neg_mask_other = ~(sample.get_background_mask() | binary_gt_mask)

        self._neg_masks = {
            "bg": neg_mask_bg,
            "other": neg_mask_other,
            "border": neg_mask_border,
            "required": neg_masks,
        }

    def _sample_mask(self, sample: DSample):
        """Pick one root object (or several to merge) and collect their regions.

        Returns:
            ``(gt_mask, pos_masks, neg_masks)``: the element-wise maximum of
            the selected objects' target masks, and the positive / negative
            sampling regions after ``_positive_erode``.
        """
        root_obj_ids = sample.root_objects

        if len(root_obj_ids) > 1 and random.random() < self.merge_objects_prob:
            max_selected_objects = min(len(root_obj_ids), self.max_num_merged_objects)
            num_selected_objects = np.random.randint(2, max_selected_objects + 1)
            random_ids = random.sample(root_obj_ids, num_selected_objects)
        else:
            random_ids = [random.choice(root_obj_ids)]

        gt_mask = None
        pos_segments = []
        neg_segments = []
        for obj_id in random_ids:
            (
                obj_gt_mask,
                obj_pos_segments,
                obj_neg_segments,
            ) = self._sample_from_masks_layer(obj_id, sample)
            if gt_mask is None:
                gt_mask = obj_gt_mask
            else:
                gt_mask = np.maximum(gt_mask, obj_gt_mask)

            pos_segments.extend(obj_pos_segments)
            neg_segments.extend(obj_neg_segments)

        pos_masks = [self._positive_erode(x) for x in pos_segments]
        neg_masks = [self._positive_erode(x) for x in neg_segments]

        return gt_mask, pos_masks, neg_masks

    def _sample_from_masks_layer(self, obj_id, sample: DSample):
        """Build the target mask and click regions for one root object.

        Without hierarchy the object's own mask is the only positive region
        and there are no required negative regions. With hierarchy a node is
        chosen by a random descent from ``obj_id``.

        Returns:
            ``(gt_mask, [pos_mask], negative_segments)``.
        """
        objs_tree = sample._objects

        if not self.use_hierarchy:
            node_mask = sample.get_object_mask(obj_id)
            gt_mask = (
                sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask
            )
            return gt_mask, [node_mask], []

        def _select_node(node_id):
            # Stop at a leaf, or at any node with probability 0.5; otherwise
            # continue the descent into a random child.
            node_info = objs_tree[node_id]
            if not node_info["children"] or random.random() < 0.5:
                return node_id
            return _select_node(random.choice(node_info["children"]))

        selected_node = _select_node(obj_id)
        node_info = objs_tree[selected_node]
        node_mask = sample.get_object_mask(selected_node)
        gt_mask = (
            sample.get_soft_object_mask(selected_node)
            if self.soft_targets
            else node_mask
        )
        pos_mask = node_mask.copy()

        negative_segments = []
        if node_info["parent"] is not None and node_info["parent"] in objs_tree:
            # The part of the parent outside the node is a required negative region.
            parent_mask = sample.get_object_mask(node_info["parent"])
            negative_segments.append(parent_mask & ~node_mask)

        for child_id in node_info["children"]:
            # Children covering less than 10% of the node's area are excluded
            # from the positive click region (the target itself is unchanged).
            if objs_tree[child_id]["area"] / node_info["area"] < 0.10:
                child_mask = sample.get_object_mask(child_id)
                pos_mask = pos_mask & ~child_mask

        if node_info["children"]:
            # Randomly "disable" up to three children: they are cut out of the
            # target and become required negative regions.
            max_disabled_children = min(len(node_info["children"]), 3)
            num_disabled_children = np.random.randint(0, max_disabled_children + 1)
            disabled_children = random.sample(
                node_info["children"],
                num_disabled_children,
            )

            for child_id in disabled_children:
                child_mask = sample.get_object_mask(child_id)
                pos_mask = pos_mask & ~child_mask
                if self.soft_targets:
                    soft_child_mask = sample.get_soft_object_mask(child_id)
                    gt_mask = np.minimum(gt_mask, 1.0 - soft_child_mask)
                else:
                    gt_mask = gt_mask & ~child_mask
                negative_segments.append(child_mask)

        return gt_mask, [pos_mask], negative_segments

    def sample_points(self):
        """Draw clicks for the target chosen by ``sample_object``.

        Returns:
            A list of ``2 * max_num_points`` entries: the positive points
            followed by the negative points, each part padded with
            ``(-1, -1, -1)``.
        """
        assert self._selected_mask is not None
        pos_points = self._multi_mask_sample_points(
            self._selected_masks,
            is_negative=[False] * len(self._selected_masks),
            with_first_click=self.first_click_center,
        )

        # The three negative strategies form one weighted mixture entry that
        # follows the required negative regions.
        neg_strategy = [
            (self._neg_masks[k], prob)
            for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)
        ]
        neg_masks = self._neg_masks["required"] + [neg_strategy]
        neg_points = self._multi_mask_sample_points(
            neg_masks,
            is_negative=[False] * len(self._neg_masks["required"]) + [True],
        )

        return pos_points + neg_points

    def _multi_mask_sample_points(
        self,
        selected_masks,
        is_negative,
        with_first_click=False,
    ):
        """Sample points from several regions and return exactly ``max_num_points``.

        Args:
            selected_masks: regions to sample from; each entry is a mask or a
                list of ``(mask, probability)`` pairs.
            is_negative: per-region flag choosing the point-count distribution
                used by ``_sample_points``.
            with_first_click: forwarded to ``_sample_points``.
        """
        selected_masks = selected_masks[: self.max_num_points]

        each_obj_points = [
            self._sample_points(
                mask,
                is_negative=is_negative[i],
                with_first_click=with_first_click,
            )
            for i, mask in enumerate(selected_masks)
        ]
        each_obj_points = [x for x in each_obj_points if len(x) > 0]

        points = []
        if len(each_obj_points) == 1:
            points = each_obj_points[0]
        elif len(each_obj_points) > 1:
            if self.only_one_first_click:
                each_obj_points = each_obj_points[:1]

            # Keep only the first point of each region ...
            points = [obj_points[0] for obj_points in each_obj_points]

            # ... and draw the remaining points from the union of all regions,
            # with every region's weight divided by the number of regions.
            aggregated_masks_with_prob = []
            for x in selected_masks:
                if (
                    isinstance(x, (list, tuple))
                    and x
                    and isinstance(x[0], (list, tuple))
                ):
                    for t, prob in x:
                        aggregated_masks_with_prob.append(
                            (t, prob / len(selected_masks)),
                        )
                else:
                    aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks)))

            other_points_union = self._sample_points(
                aggregated_masks_with_prob,
                is_negative=True,
            )
            if len(other_points_union) + len(points) <= self.max_num_points:
                points.extend(other_points_union)
            else:
                points.extend(
                    random.sample(
                        other_points_union,
                        self.max_num_points - len(points),
                    ),
                )

        # Pad with placeholder points so the result has a fixed length.
        if len(points) < self.max_num_points:
            points.extend([(-1, -1, -1)] * (self.max_num_points - len(points)))

        return points

    def _sample_points(self, mask, is_negative=False, with_first_click=False):
        """Sample a random number of points from one region.

        Args:
            mask: a mask array, or a list/tuple of ``(mask, probability)``
                pairs whose probabilities sum to 1.
            is_negative: when True the count is drawn from
                ``0..max_num_points``, otherwise from ``1..max_num_points``.
            with_first_click: when True (and ``mask`` is a single mask) the
                first point is taken from ``get_point_candidates``.

        Returns:
            A list of ``[*coords, index]`` lists, where ``index`` is 0 for a
            first click and 100 otherwise. May hold fewer points than drawn,
            because draws that hit an empty region add nothing.
        """
        if is_negative:
            num_points = np.random.choice(
                np.arange(self.max_num_points + 1),
                p=self._neg_probs,
            )
        else:
            num_points = 1 + np.random.choice(
                np.arange(self.max_num_points),
                p=self._pos_probs,
            )

        indices_probs = None
        if isinstance(mask, (list, tuple)):
            indices_probs = [x[1] for x in mask]
            indices = [(np.argwhere(x), prob) for x, prob in mask]
            if indices_probs:
                assert math.isclose(sum(indices_probs), 1.0)
        else:
            indices = np.argwhere(mask)

        points = []
        for j in range(num_points):
            first_click = with_first_click and j == 0 and indices_probs is None

            if first_click:
                point_indices = get_point_candidates(
                    mask,
                    k=self.sfc_inner_k,
                    full_prob=self.sfc_full_inner_prob,
                )
            elif indices_probs:
                # Mixture: choose one of the regions according to its weight.
                point_indices_indx = np.random.choice(
                    np.arange(len(indices)),
                    p=indices_probs,
                )
                point_indices = indices[point_indices_indx][0]
            else:
                point_indices = indices

            num_indices = len(point_indices)
            if num_indices > 0:
                point_indx = 0 if first_click else 100
                click = point_indices[np.random.randint(0, num_indices)].tolist() + [
                    point_indx,
                ]
                points.append(click)

        return points

    def _positive_erode(self, mask):
        """Randomly shrink ``mask`` so clicks tend to fall away from its edge.

        The input is returned unchanged when the random draw exceeds
        ``positive_erode_prob`` or when 10 or fewer pixels survive the erosion.
        """
        if random.random() > self.positive_erode_prob:
            return mask

        kernel = np.ones((3, 3), np.uint8)
        eroded_mask = cv2.erode(
            mask.astype(np.uint8),
            kernel,
            iterations=self.positive_erode_iters,
        ).astype(bool)

        if eroded_mask.sum() > 10:
            return eroded_mask
        return mask

    def _get_border_mask(self, mask):
        """Return the band around ``mask`` produced by dilating it.

        The number of 3x3 dilation iterations is
        ``ceil(expand_ratio * sqrt(mask area))``; the pixels of ``mask``
        itself are zeroed in the result.
        """
        expand_r = int(np.ceil(self.expand_ratio * np.sqrt(mask.sum())))
        kernel = np.ones((3, 3), np.uint8)
        expanded_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=expand_r)
        expanded_mask[mask.astype(bool)] = 0
        return expanded_mask


@lru_cache(maxsize=None)
def generate_probs(max_num_points, gamma):
    """Return normalized geometric weights ``gamma ** i`` for ``i < max_num_points``.

    Args:
        max_num_points: number of probabilities to generate.
        gamma: ratio between consecutive weights.

    Returns:
        A float64 array that sums to 1 (for ``max_num_points >= 1``). The
        array is cached and shared between callers, so it must not be
        modified in place.
    """
    probs = []
    last_value = 1
    for _ in range(max_num_points):
        probs.append(last_value)
        last_value *= gamma

    # Force a floating dtype: with a single point (or an integer gamma) the
    # list holds only ints, and in-place true division of an integer array
    # raises a casting error.
    probs = np.array(probs, dtype=np.float64)
    probs /= probs.sum()

    return probs


def get_point_candidates(obj_mask, k=1.7, full_prob=0.0):
    """Return candidate coordinates for a first click inside ``obj_mask``.

    Args:
        obj_mask: 2-D object mask (non-zero means object).
        k: when positive, candidates are the pixels whose distance to the
            object boundary exceeds ``max_distance / k``. Otherwise a single
            pixel is drawn with probability proportional to that distance.
        full_prob: probability of using every object pixel as a candidate.

    Returns:
        An integer array of shape ``(N, 2)`` holding array indices; ``N`` is 0
        when the mask contains no object pixels.
    """
    if full_prob > 0 and random.random() < full_prob:
        # Callers index the result as a list of coordinates, so return the
        # object's pixel coordinates rather than the mask itself.
        return np.argwhere(obj_mask)

    # Pad by one pixel so the image border counts as an object boundary.
    padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), "constant")

    dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1]
    if k > 0:
        inner_mask = dt > dt.max() / k
        return np.argwhere(inner_mask)

    # Normalize in float64: np.random.choice rejects probability vectors whose
    # sum drifts from 1 by float32-sized rounding errors.
    prob_map = dt.flatten().astype(np.float64)
    total = prob_map.sum()
    if total <= 0:
        # Empty object: no candidates, consistent with the k > 0 branch.
        return np.empty((0, 2), dtype=np.int64)
    prob_map /= total
    click_indx = np.random.choice(len(prob_map), p=prob_map)
    click_coords = np.unravel_index(click_indx, dt.shape)
    return np.array([click_coords])


In [14]:
import random
from copy import deepcopy
from pathlib import Path
from types import SimpleNamespace

import albumentations as A
import cv2
import numpy as np
import torch
from albumentations import Compose, DualTransform, PadIfNeeded, RandomCrop
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large
from tqdm import tqdm


class ISModel(nn.Module):
    """Click-driven segmentation network built on DeepLabV3 / MobileNetV3-Large.

    The network consumes a 4-channel tensor (3 image channels followed by the
    previous mask) together with a tensor of encoded clicks. The clicks are
    rasterized by ``DistMaps`` into two maps, concatenated with the previous
    mask (``coord_feature_ch`` = 3 extra channels) and fed, along with the
    normalized image, through a replacement stem into the segmentation model.
    """

    # Your model should not have required parameters for init
    def __init__(
        self,
        pretrained=False,
    ):
        super().__init__()

        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        self.normalization = BatchImageNormalize(channel_mean, channel_std)

        # Positive, Negative and Previous Mask
        self.coord_feature_ch = 3
        self.dist_maps = DistMaps(
            norm_radius=5,
            spatial_scale=1.0,
            use_disks=True,
        )

        if pretrained:
            backbone_weights = MobileNet_V3_Large_Weights.IMAGENET1K_V1
        else:
            backbone_weights = None
        self.feature_extractor = deeplabv3_mobilenet_v3_large(
            num_classes=1,
            weights_backbone=backbone_weights,
        )

        # Add user clicks and mask on input: swap the first backbone conv for a
        # small stack that accepts the extra channels and keeps the stride-2
        # 3x3 convolution afterwards.
        stem = self.feature_extractor.backbone["0"]
        replaced_conv = stem[0]
        stem_width = replaced_conv.out_channels
        fused_channels = replaced_conv.in_channels + self.coord_feature_ch

        stem[0] = nn.Sequential(
            nn.Conv2d(fused_channels, stem_width, kernel_size=1, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(
                stem_width,
                stem_width,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1),
                bias=False,
            ),
        )

        # Remove Dropout layer
        self.feature_extractor.classifier[0].project[3] = nn.Identity()

        # Will be used in testing
        self.pred_thr = 0.5

    def forward(self, image, points):
        """Run the model; returns ``{"instances": logits}``."""
        rgb, prev_mask = self.prepare_input(image)
        click_features = self.get_coord_features(rgb, prev_mask, points)
        return self.backbone_forward(rgb, click_features)

    def prepare_input(self, image):
        """Split the input into a normalized 3-channel image and the remaining channels."""
        prev_mask = image[:, 3:, :, :]
        rgb = self.normalization(image[:, :3, :, :])
        return rgb, prev_mask

    def backbone_forward(self, image, coord_features):
        """Concatenate image and click features and return the ``"out"`` head."""
        stacked = torch.cat((image, coord_features), dim=1)
        logits = self.feature_extractor(stacked)["out"]
        return {"instances": logits}

    def get_coord_features(self, image, prev_mask, points):
        """Prepend the previous mask to the rasterized click maps."""
        click_maps = self.dist_maps(image, points)
        return torch.cat((prev_mask, click_maps), dim=1)

    def restore_from_checkpoint(self, checkpoint_path, device):
        """Load weights from the ``"state_dict"`` entry of a checkpoint file.

        Returns ``self`` so the call can be chained.
        """
        stored = torch.load(checkpoint_path, weights_only=True, map_location=device)
        self.load_state_dict(stored["state_dict"])
        return self


class Predictor:
    """Inference wrapper that feeds an image plus user clicks to the model.

    The predictor keeps the current image and the last predicted probability
    map. Each call to ``get_prediction`` resamples the inputs to
    ``NET_INPUT_SIZE``, runs the network and resamples the result back to the
    original resolution.
    """

    # (height, width) every input is resampled to before the network runs.
    NET_INPUT_SIZE = (400, 400)

    def __init__(self, model, device):
        self.original_image = None
        self.device = device
        self.prev_prediction = None
        self.net = model
        self.to_tensor = transforms.ToTensor()

    def set_input_image(self, image):
        """Store ``image`` as a batched tensor on the device and reset the previous mask."""
        image_nd = self.to_tensor(image)
        self.original_image = image_nd.to(self.device)
        if len(self.original_image.shape) == 3:
            self.original_image = self.original_image.unsqueeze(0)
        self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :])

    def get_prediction(self, clicker, prev_mask=None):
        """Predict a probability map for the clicks currently held by ``clicker``.

        Args:
            clicker: Object exposing ``get_clicks()``; the clicks are deep-copied
                so rescaling their coordinates does not affect the caller.
            prev_mask: Optional previous mask tensor. Defaults to the
                prediction produced by the last call.

        Returns:
            A 2-D numpy array of sigmoid probabilities at the original
            image resolution.
        """
        clicks_list = deepcopy(clicker.get_clicks())

        input_image = self.original_image
        if prev_mask is None:
            prev_mask = self.prev_prediction

        input_image = torch.cat((input_image, prev_mask), dim=1)
        prev_size = input_image.shape[2:]
        net_height, net_width = self.NET_INPUT_SIZE

        # Scale clicks too
        for click in clicks_list:
            click.coords = (
                click.coords[0] / (prev_size[0] / float(net_height)),
                click.coords[1] / (prev_size[1] / float(net_width)),
            )

        # Inference only: without no_grad the output would require grad, which
        # makes `.numpy()` below raise and keeps the autograd graph alive
        # through `self.prev_prediction`.
        with torch.no_grad():
            input_image = torch.nn.functional.interpolate(
                input_image,
                self.NET_INPUT_SIZE,
                mode="bilinear",
                align_corners=True,
            )
            prediction = self._get_prediction(input_image, [clicks_list])
            prediction = torch.nn.functional.interpolate(
                prediction,
                prev_size,
                mode="bilinear",
                align_corners=True,
            )
            prediction = torch.sigmoid(prediction)

        self.prev_prediction = prediction
        return prediction.cpu().numpy()[0, 0]

    def _get_prediction(self, image_nd, clicks_lists):
        """Encode the clicks and return the network's ``"instances"`` logits."""
        points_nd = self.get_points_nd(clicks_lists)
        return self.net(image_nd, points_nd)["instances"]

    def get_points_nd(self, clicks_lists):
        """Pack click lists into a ``(batch, 2 * max_points, 3)`` tensor.

        For every click list the positive clicks come first, then the negative
        ones. Both groups are padded with ``(-1, -1, -1)`` up to the largest
        group size found in the batch (at least 1).
        """
        # bool(...) keeps the counts plain Python ints even when `is_positive`
        # is a numpy boolean.
        num_pos_clicks = [
            sum(bool(click.is_positive) for click in clicks_list)
            for clicks_list in clicks_lists
        ]
        num_neg_clicks = [
            len(clicks_list) - num_pos
            for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
        ]
        num_max_points = max(num_pos_clicks + num_neg_clicks)
        num_max_points = max(1, num_max_points)

        padding_click = (-1, -1, -1)
        total_clicks = []
        for clicks_list in clicks_lists:
            pos_clicks, neg_clicks = [], []
            for click in clicks_list:
                # Branch explicitly instead of indexing a list with the flag:
                # numpy booleans are not valid list indices on recent NumPy.
                bucket = pos_clicks if click.is_positive else neg_clicks
                bucket.append(click.coords_and_indx)

            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [padding_click]
            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [padding_click]

            total_clicks.append(pos_clicks + neg_clicks)

        return torch.tensor(total_clicks, device=self.device)


class DistMaps(torch.nn.Module):
    """Rasterize encoded clicks into two per-pixel maps.

    ``points`` is expected as ``(batch, 2 * n, 3)``: the first ``n`` entries of
    each sample form one group and the last ``n`` the other (positive then
    negative, as packed by ``Predictor.get_points_nd``). Entries whose two
    coordinates are both negative are treated as padding.

    With ``use_disks`` the output is a binary disk of radius
    ``norm_radius * spatial_scale`` around every click; otherwise it is
    ``tanh(2 * distance / (norm_radius * spatial_scale))`` to the nearest click.
    """

    def __init__(self, norm_radius=5, spatial_scale=1.0, use_disks=True):
        super().__init__()
        self.spatial_scale = spatial_scale
        self.norm_radius = norm_radius
        self.use_disks = use_disks

    def get_coord_features(self, points, rows, cols):
        """Return a ``(batch, 2, rows, cols)`` float tensor of click maps."""
        clicks_per_group = points.shape[1] // 2
        flat_points = points.view(-1, points.size(2))
        centers, _ = torch.split(flat_points, [2, 1], dim=1)
        num_clicks = centers.size(0)

        # Padding entries carry negative coordinates in both dimensions.
        is_padding = centers.max(dim=1).values < 0

        grid_opts = {"dtype": torch.float32, "device": centers.device}
        grid_rows, grid_cols = torch.meshgrid(
            torch.arange(rows, **grid_opts),
            torch.arange(cols, **grid_opts),
            indexing="ij",
        )
        offsets = torch.stack((grid_rows, grid_cols), dim=0)
        offsets = offsets.unsqueeze(0).repeat(num_clicks, 1, 1, 1)

        # Per-click displacement of every pixel from the (scaled) click centre.
        scaled_centers = centers * self.spatial_scale
        scaled_centers = scaled_centers.view(num_clicks, centers.size(1), 1, 1)
        offsets.add_(-scaled_centers)
        if not self.use_disks:
            offsets.div_(self.norm_radius * self.spatial_scale)
        offsets.mul_(offsets)

        # Squared distance = squared row offset + squared column offset.
        sq_dist = offsets[:, :1] + offsets[:, 1:]

        # Push padding clicks infinitely far away so they never win the min.
        sq_dist[is_padding, :, :, :] = 1e6

        grouped = sq_dist.view(-1, clicks_per_group, 1, rows, cols)
        nearest = grouped.min(dim=1)[0]  # -> (bs * num_masks * 2) x 1 x h x w
        nearest = nearest.view(-1, 2, rows, cols)

        if self.use_disks:
            radius_sq = (self.norm_radius * self.spatial_scale) ** 2
            return (nearest <= radius_sq).float()

        nearest.sqrt_().mul_(2).tanh_()
        return nearest

    def forward(self, x, coords):
        """Rasterize ``coords`` at the spatial size of ``x``."""
        height, width = x.shape[2], x.shape[3]
        return self.get_coord_features(coords, height, width)


class BatchImageNormalize:
    """Per-channel ``(x - mean) / std`` normalization for batched images.

    ``mean`` and ``std`` are stored with shape ``(1, C, 1, 1)`` so they
    broadcast over a batch with channels in dimension 1.
    """

    def __init__(self, mean, std, dtype=torch.float):
        def as_channel_stat(values):
            # Add singleton batch / height / width axes around the channel axis.
            return torch.as_tensor(values, dtype=dtype)[None, :, None, None]

        self.mean = as_channel_stat(mean)
        self.std = as_channel_stat(std)

    def __call__(self, tensor):
        """Return a normalized copy of ``tensor``; the input is left untouched."""
        normalized = tensor.clone()
        target_device = normalized.device

        normalized.sub_(self.mean.to(target_device))
        normalized.div_(self.std.to(target_device))
        return normalized


def get_next_points(pred, gt, points, click_indx, pred_thresh=0.5):
    """Simulate click to the area with largest error.

    For every sample the false-negative and false-positive regions are
    compared through their distance transforms. The region with the larger
    maximum distance receives one click, drawn uniformly among pixels whose
    distance exceeds half of that maximum. A click in the false-negative
    region is written at slot ``n - click_indx`` and one in the false-positive
    region at slot ``2 * n - click_indx``, where ``n = points.size(1) // 2``.

    Returns:
        A copy of ``points`` with the new clicks filled in; the input tensor is
        not modified.
    """
    assert click_indx > 0
    pred_np = pred.cpu().numpy()[:, 0, :, :]
    gt_np = gt.cpu().numpy()[:, 0, :, :] > 0.5

    # One-pixel zero border so pixels on the image edge get a finite distance.
    pad_spec = ((0, 0), (1, 1), (1, 1))
    missed = np.pad(gt_np & (pred_np < pred_thresh), pad_spec, "constant").astype(np.uint8)
    spurious = np.pad(~gt_np & (pred_np > pred_thresh), pad_spec, "constant").astype(np.uint8)

    group_size = points.size(1) // 2
    updated = points.clone()

    for sample in range(missed.shape[0]):
        fn_dt = cv2.distanceTransform(missed[sample], cv2.DIST_L2, 5)[1:-1, 1:-1]
        fp_dt = cv2.distanceTransform(spurious[sample], cv2.DIST_L2, 5)[1:-1, 1:-1]

        fn_peak = np.max(fn_dt)
        fp_peak = np.max(fp_dt)

        is_positive = fn_peak > fp_peak
        if is_positive:
            error_dt = fn_dt
        else:
            error_dt = fp_dt

        candidates = np.argwhere(error_dt > max(fn_peak, fp_peak) / 2.0)
        if len(candidates) == 0:
            # No error region for this sample: leave its points as they are.
            continue

        row, col = candidates[np.random.randint(0, len(candidates))]
        group_end = group_size if is_positive else 2 * group_size
        slot = group_end - click_indx
        for channel, value in enumerate((row, col, click_indx)):
            updated[sample, slot, channel] = float(value)

    return updated


class ISTrainer:
    """Training / validation loop for a click-based segmentation model.

    Args:
        model: Network called as ``model(image_with_prev_mask, points)`` and
            returning a dict with an ``"instances"`` entry.
        cfg: Namespace providing ``batch_size``, ``val_batch_size``, ``device``,
            ``CHECKPOINTS_PATH`` and ``VIS_PATH``.
        instance_loss: Callable ``loss(logits, target)``.
        trainset, valset: Datasets exposing ``get_samples_number()`` and
            yielding dicts with ``"images"``, ``"instances"`` and ``"points"``.
        image_dump_interval: Dump a visualization every N training steps
            (disabled when not positive).
        checkpoint_interval: Save an epoch-numbered checkpoint every N epochs
            (disabled when not positive).
        max_initial_points: Index at which the points array is split when
            drawing the two click groups.
        max_interactive_clicks: Upper bound on simulated corrective clicks per
            batch.
    """

    def __init__(
        self,
        model,
        cfg,
        instance_loss,
        trainset,
        valset,
        image_dump_interval=10,
        checkpoint_interval=10,
        max_initial_points=0,
        max_interactive_clicks=0,
    ):
        self.cfg = cfg
        self.max_initial_points = max_initial_points
        self.instance_loss = instance_loss
        self.max_interactive_clicks = max_interactive_clicks
        self.checkpoint_interval = checkpoint_interval
        self.image_dump_interval = image_dump_interval
        self.trainset = trainset
        self.valset = valset

        train_size = trainset.get_samples_number()
        print(f"Dataset of {train_size} samples was loaded for training.")
        val_size = valset.get_samples_number()
        print(f"Dataset of {val_size} samples was loaded for validation.")

        self.train_data = DataLoader(
            trainset,
            batch_size=cfg.batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            num_workers=4,
        )

        self.val_data = DataLoader(
            valset,
            batch_size=cfg.val_batch_size,
            shuffle=False,
            drop_last=True,
            pin_memory=True,
            num_workers=4,
        )

        self.device = cfg.device
        self.net = model.to(self.device)
        self.optim = torch.optim.AdamW(self.net.parameters(), lr=3e-4)

    def run(self, num_epochs, validation=True):
        """Train for ``num_epochs`` epochs, optionally validating after each one."""
        print(f"Total Epochs: {num_epochs}")
        for epoch in range(num_epochs):
            self.training(epoch)
            if validation:
                self.validation(epoch)

    def _should_save_epoch_checkpoint(self, epoch):
        """Tell whether ``epoch`` falls on the periodic checkpoint schedule.

        A non-positive ``checkpoint_interval`` disables periodic checkpoints,
        mirroring how ``image_dump_interval`` is treated, instead of failing
        with a modulo-by-zero.
        """
        interval = self.checkpoint_interval
        return interval > 0 and epoch % interval == 0

    def training(self, epoch):
        """Run one training epoch and save checkpoints afterwards."""
        tbar = tqdm(self.train_data, ncols=100)
        self.net.train()
        train_loss = 0.0

        for i, batch_data in enumerate(tbar):
            global_step = epoch * len(self.train_data) + i

            loss, splitted_batch_data, outputs = self.batch_forward(batch_data)

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            train_loss += loss.item()

            if (
                self.image_dump_interval > 0
                and global_step % self.image_dump_interval == 0
            ):
                self.save_visualization(
                    splitted_batch_data,
                    outputs,
                    global_step,
                    prefix="train",
                )

            tbar.set_description(f"Epoch {epoch}, training loss {train_loss/(i+1):.4f}")

        # Saved after every epoch with epoch=None; presumably the rolling
        # "latest" checkpoint -- confirm against save_checkpoint.
        save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, epoch=None)

        if self._should_save_epoch_checkpoint(epoch):
            save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, epoch=epoch)

    def validation(self, epoch):
        """Run one pass over the validation loader and report the mean loss."""
        tbar = tqdm(self.val_data, ncols=100)
        self.net.eval()
        val_loss = 0

        for i, batch_data in enumerate(tbar):
            loss, _, _ = self.batch_forward(
                batch_data,
                validation=True,
            )

            val_loss += loss.item()

            tbar.set_description(
                f"Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}",
            )

    def batch_forward(self, batch_data, validation=False):
        """Simulate a random number of corrective clicks, then compute the loss.

        Returns:
            ``(loss, batch_data, output)`` where ``batch_data`` has been moved
            to the device and its ``"points"`` entry replaced by the points
            including the simulated clicks.
        """
        with torch.set_grad_enabled(not validation):
            batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
            image, gt_mask, points = (
                batch_data["images"],
                batch_data["instances"],
                batch_data["points"],
            )

            # The first interaction starts from an all-zero previous mask.
            prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :]

            # Make interactive steps
            with torch.no_grad():
                num_iters = random.randint(0, self.max_interactive_clicks)

                for click_indx in range(num_iters):
                    # The click simulation runs in eval mode; training mode is
                    # restored right after each simulated step.
                    if not validation:
                        self.net.eval()

                    net_input = torch.cat((image, prev_output), dim=1)
                    prev_output = self.net(net_input, points)["instances"]
                    prev_output = torch.sigmoid(prev_output)

                    points = get_next_points(
                        prev_output,
                        gt_mask,
                        points,
                        click_indx + 1,
                    )

                    if not validation:
                        self.net.train()

            batch_data["points"] = points

            net_input = torch.cat((image, prev_output), dim=1)
            output = self.net(net_input, points)

            loss = self.instance_loss(output["instances"], batch_data["instances"])
            loss = torch.mean(loss)

        return loss, batch_data, output

    def save_visualization(
        self,
        splitted_batch_data,
        outputs,
        global_step,
        prefix,
    ):
        """Write a JPEG of the first sample: image with clicks, target, prediction.

        The file goes to ``cfg.VIS_PATH / prefix`` and is named after the
        zero-padded ``global_step``.
        """
        output_images_path = self.cfg.VIS_PATH / prefix

        # exist_ok avoids the check-then-create race of a separate exists() test.
        output_images_path.mkdir(parents=True, exist_ok=True)
        image_name_prefix = f"{global_step:06d}"

        def _save_image(suffix, image):
            cv2.imwrite(
                str(output_images_path / f"{image_name_prefix}_{suffix}.jpg"),
                image,
                [cv2.IMWRITE_JPEG_QUALITY, 85],
            )

        images = splitted_batch_data["images"]
        points = splitted_batch_data["points"]
        instance_masks = splitted_batch_data["instances"]

        gt_instance_masks = instance_masks.cpu().numpy()
        predicted_instance_masks = (
            torch.sigmoid(outputs["instances"]).detach().cpu().numpy()
        )
        points = points.detach().cpu().numpy()

        # Only the first sample of the batch is visualized.
        image_blob, points = images[0], points[0]
        gt_mask = np.squeeze(gt_instance_masks[0], axis=0)
        predicted_mask = np.squeeze(predicted_instance_masks[0], axis=0)

        image = image_blob.cpu().numpy() * 255
        image = image.transpose((1, 2, 0))

        # The two click groups are drawn with different colours; the split
        # assumes the first `max_initial_points` rows form the first group
        # (TODO confirm against the points sampler).
        image_with_points = draw_points(
            image,
            points[: self.max_initial_points],
            (0, 255, 0),
        )
        image_with_points = draw_points(
            image_with_points,
            points[self.max_initial_points :],
            (0, 0, 255),
        )

        # Negative target values are shown as a fixed intermediate intensity.
        gt_mask[gt_mask < 0] = 0.25
        gt_mask = draw_probmap(gt_mask)
        predicted_mask = draw_probmap(predicted_mask)
        viz_image = np.hstack((image_with_points, gt_mask, predicted_mask))
        viz_image = viz_image.astype(np.uint8)

        # Reverse the channel order before handing the image to cv2.imwrite.
        _save_image("instance_segmentation", viz_image[:, :, ::-1])


class UniformRandomResize(DualTransform):
    """Example of how to implement spatial augmentations.

    Resizes the image by a factor drawn uniformly from ``scale_range``.
    Note that we need to recalculate click (keypoint) positions!
    """

    def __init__(
        self,
        scale_range,
        interpolation=cv2.INTER_LINEAR,
        always_apply=True,
        p=1,
    ):
        super().__init__(always_apply, p)
        self.interpolation = interpolation
        self.scale_range = scale_range

    def get_params_dependent_on_targets(self, params):
        """Draw the scale factor and derive the target size from the image shape."""
        factor = random.uniform(*self.scale_range)
        source_shape = params["image"].shape
        return {
            "new_height": int(round(source_shape[0] * factor)),
            "new_width": int(round(source_shape[1] * factor)),
        }

    def apply(
        self,
        img,
        new_height=0,
        new_width=0,
        interpolation=cv2.INTER_LINEAR,
        **params,
    ):
        """Resize ``img`` to ``(new_height, new_width)``."""
        resized = A.augmentations.geometric.resize.Resize(
            height=new_height,
            width=new_width,
            interpolation=interpolation,
        )(image=img)
        return resized["image"]

    def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params):
        """Scale a keypoint by the same per-axis factors as the image."""
        return A.augmentations.geometric.functional.keypoint_scale(
            keypoint,
            new_width / params["cols"],
            new_height / params["rows"],
        )

    def apply_to_bbox(self, *_, **__):
        """Bounding boxes are not supported by this transform."""
        raise NotImplementedError()

    def get_transform_init_args_names(self):
        """Names of the constructor arguments reported for this transform."""
        return ("scale_range", "interpolation")

    @property
    def targets_as_params(self):
        """Targets whose data ``get_params_dependent_on_targets`` needs."""
        return ["image"]


def train_segmentation(
    dataset_root="./COCO_LVIS",
    exp_dir="./experiments",
    num_epochs=10,
    device=None,
):
    """Train the interactive segmentation model on the COCO_LVIS dataset.

    Called without arguments it behaves as before on a CUDA machine.

    Args:
        dataset_root: Directory passed to ``CocoLvisDataset`` for both splits.
        exp_dir: Directory that receives the ``checkpoints`` and ``vis``
            sub-directories.
        num_epochs: Number of training epochs.
        device: Device name or ``torch.device``. When ``None``, CUDA is used
            if available and the CPU otherwise.

    Raises:
        FileNotFoundError: If ``dataset_root`` is not an existing directory.
    """
    # Fail fast, before the model is built and pretrained weights are
    # downloaded, when the dataset is not where it is expected.
    if not Path(dataset_root).is_dir():
        raise FileNotFoundError(
            f"Dataset directory not found: {dataset_root} "
            "(expected the COCO_LVIS dataset with its train/val splits)",
        )

    input_size = (400, 400)
    model = ISModel(pretrained=True)

    cfg = SimpleNamespace()
    exp_path = Path(exp_dir)
    cfg.EXP_PATH = exp_path
    cfg.CHECKPOINTS_PATH = exp_path / "checkpoints"
    cfg.VIS_PATH = exp_path / "vis"

    cfg.EXP_PATH.mkdir(parents=True, exist_ok=True)
    cfg.CHECKPOINTS_PATH.mkdir(exist_ok=True)
    cfg.VIS_PATH.mkdir(exist_ok=True)

    if device is None:
        # Fall back to the CPU instead of crashing on machines without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg.device = torch.device(device)

    cfg.max_initial_points = 24
    cfg.batch_size = 48
    cfg.val_batch_size = cfg.batch_size

    instance_loss = torch.nn.BCEWithLogitsLoss()

    # You can add more augmentations here
    h, w = input_size
    train_augmentator = Compose(
        [
            UniformRandomResize(scale_range=(0.75, 1.40)),
            PadIfNeeded(
                min_height=h,
                min_width=w,
                border_mode=cv2.BORDER_CONSTANT,
                value=0,
            ),
            RandomCrop(*input_size),
        ],
        p=1.0,
    )

    val_augmentator = Compose(
        [
            PadIfNeeded(
                min_height=h,
                min_width=w,
                border_mode=cv2.BORDER_CONSTANT,
                value=0,
            ),
            RandomCrop(*input_size),
        ],
        p=1.0,
    )

    points_sampler = MultiPointSampler(
        cfg.max_initial_points,
        prob_gamma=0.80,
        merge_objects_prob=0.15,
        max_num_merged_objects=2,
    )

    trainset = CocoLvisDataset(
        dataset_root,
        split="train",
        augmentator=train_augmentator,
        min_object_area=1000,
        points_sampler=points_sampler,
        epoch_len=30000,
        stuff_prob=0.30,
    )

    valset = CocoLvisDataset(
        dataset_root,
        split="val",
        augmentator=val_augmentator,
        min_object_area=1000,
        points_sampler=points_sampler,
        epoch_len=2000,
    )

    trainer = ISTrainer(
        model,
        cfg,
        instance_loss,
        trainset,
        valset,
        checkpoint_interval=5,
        image_dump_interval=1000,
        max_initial_points=cfg.max_initial_points,
        max_interactive_clicks=3,
    )

    trainer.run(num_epochs=num_epochs)


In [15]:
# Launch training. The dataset must be available under ./COCO_LVIS; the run
# recorded below stopped with FileNotFoundError because
# COCO_LVIS/train/hannotation.pickle was missing.
train_segmentation()

Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /tmp/xdg_cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth
100%|██████████| 21.1M/21.1M [00:00<00:00, 63.8MB/s]


FileNotFoundError: [Errno 2] No such file or directory: 'COCO_LVIS/train/hannotation.pickle'