In [2]:
#| default_exp structures
#| export
import torch
from typing import Tuple, List, Dict, Any, Union, Optional
from torch import device
import warnings
import itertools

class Boxes3D:
    """
    A set of N axis-aligned 3D boxes stored as an Nx6 float32 tensor.

    Each row is (x1, y1, z1, x2, y2, z2): the minimum corner followed by the
    maximum corner.

    Attributes:
        tensor (Tensor): float32 matrix of shape (N, 6).
    """

    def __init__(self, tensor: torch.Tensor):

        """
        Args:
            tensor (Tensor[Float]): a Nx6 matrix where each row is (x1, y1, z1, x2, y2, z2).
                Non-tensor input is converted to a float32 CPU tensor; tensor input
                is cast to float32. Any empty input is normalized to shape (0, 6).
        """

        if not isinstance(tensor, torch.Tensor):
            tensor = torch.as_tensor(tensor, dtype=torch.float32, device=torch.device("cpu"))
        else:
            tensor = tensor.to(torch.float32)

        if tensor.numel() == 0:
            # Normalize empty input (e.g. ``torch.empty(0)`` produced by ``cat([])``)
            # to (0, 6) so it passes the shape check below.
            tensor = tensor.reshape((-1, 6)).to(dtype=torch.float32)

        assert tensor.dim() == 2 and tensor.size(-1) == 6, tensor.size()

        self.tensor = tensor

    def clone(self):

        """
        Clone the boxes

        Returns:
            Boxes3D: a new instance backed by a copy of ``self.tensor``.
        """

        return Boxes3D(self.tensor.clone())

    def to(self, device: torch.device):
        """
        Returns:
            Boxes3D: boxes whose tensor has been moved with ``Tensor.to(device=device)``.
        """
        return Boxes3D(self.tensor.to(device=device))

    def volume(self):
        """
        Computes the volume of all the boxes

        Returns:
            torch.Tensor: a vector of length N with the volume of each box
        """

        box = self.tensor
        volume = (box[:, 3] - box[:, 0]) * (box[:, 4] - box[:, 1]) * (box[:, 5] - box[:, 2])
        return volume

    def clip(self, box_size: Tuple[int, int, int]) -> None:

        """
        Clip (in place) the boxes by limiting x coordinates to the range [0, width],
        y coordinates to the range [0, height] and z coordinates to the range [0, depth].

        Args:
            box_size (height, width, depth): The clipping box's size
        """

        assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or Nan"

        h, w, d = box_size
        x1 = self.tensor[:, 0].clamp(min=0, max=w)
        y1 = self.tensor[:, 1].clamp(min=0, max=h)
        z1 = self.tensor[:, 2].clamp(min=0, max=d)

        x2 = self.tensor[:, 3].clamp(min=0, max=w)
        y2 = self.tensor[:, 4].clamp(min=0, max=h)
        z2 = self.tensor[:, 5].clamp(min=0, max=d)

        self.tensor = torch.stack((x1, y1, z1, x2, y2, z2), dim=-1)

    def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
        """
        Find boxes that are non-empty.
        A box is considered empty if any of its sides (along x, y or z) is no
        larger than threshold.

        Args:
            threshold (float): minimum side length (exclusive) for a box to be kept.

        Returns:
            Tensor:
                a binary vector which represents whether each box is empty
                (False) or non-empty (True).
        """

        box = self.tensor
        widths = box[:, 3] - box[:, 0]
        heights = box[:, 4] - box[:, 1]
        depths = box[:, 5] - box[:, 2]
        keep = (widths > threshold) & (heights > threshold) & (depths > threshold)
        return keep

    def __getitem__(self, item) -> "Boxes3D":
        """
        Args:
            item: int, slice, or a BoolTensor

        Returns:
            Boxes3D: Create a new :class:`Boxes3D` by indexing.

        The following usage are allowed:

        1. `new_boxes = boxes[3]`: return a `Boxes3D` which contains only one box.
        2. `new_boxes = boxes[2:10]`: return a slice of boxes.
        3. `new_boxes = boxes[vector]`, where vector is a torch.BoolTensor
           with `length = len(boxes)`. Nonzero elements in the vector will be selected.

        Note that the returned Boxes might share storage with this Boxes,
        subject to Pytorch's indexing semantics.
        """

        if isinstance(item, int):
            # An int index drops the leading dim; restore it so the result is 1x6.
            return Boxes3D(self.tensor[item].view(1, -1))

        b = self.tensor[item]
        assert b.dim() == 2, "Indexing on Boxes3D with {} failed to return a matrix!".format(item)
        return Boxes3D(b)

    def __len__(self):
        return self.tensor.shape[0]

    def __repr__(self):
        return "Boxes3D(" + str(self.tensor) + ")"

    def inside_box(self, box_size: Tuple[int, int, int], boundary_threshold: int = 0) -> torch.Tensor:
        """
        Args:
            box_size (height, width, depth): Size of the reference box.
            boundary_threshold (int): Boxes that extend beyond the reference box
                boundary by more than boundary_threshold are considered "outside".

        Returns:
            a binary vector, indicating whether each box is inside the reference box.
        """

        height, width, depth = box_size
        inds_inside = (
            (self.tensor[..., 0] >= -boundary_threshold)
            & (self.tensor[..., 1] >= -boundary_threshold)
            & (self.tensor[..., 2] >= -boundary_threshold)
            & (self.tensor[..., 3] < width + boundary_threshold)
            & (self.tensor[..., 4] < height + boundary_threshold)
            & (self.tensor[..., 5] < depth + boundary_threshold)
        )

        return inds_inside

    def get_centers(self) -> torch.Tensor:

        """
        Returns:
            the box centers as an Nx3 tensor of (x, y, z)
        """

        return (self.tensor[:, :3] + self.tensor[:, 3:]) / 2

    def scale(self, scale_x: float, scale_y: float, scale_z: float) -> None:

        """
        Scale the boxes in place with horizontal, vertical and depth scaling factors.

        Note: this mutates ``self.tensor`` in place, so any tensor sharing its
        storage (e.g. the one passed to ``__init__`` if it was already float32)
        is modified as well.
        """

        self.tensor[:, 0::3] *= scale_x  # x1, x2
        self.tensor[:, 1::3] *= scale_y  # y1, y2
        self.tensor[:, 2::3] *= scale_z  # z1, z2

    @classmethod
    def cat(cls, boxes_list: List["Boxes3D"]) -> "Boxes3D":

        """
        Concatenates a list of Boxes3D into a single Boxes3D.

        An empty list yields an empty (0x6) Boxes3D.
        """
        assert isinstance(boxes_list, (list, tuple))
        if len(boxes_list) == 0:
            return cls(torch.empty(0))

        assert all(isinstance(box, Boxes3D) for box in boxes_list)

        cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
        return cat_boxes

    @property
    def device(self) -> device:
        return self.tensor.device

    def __iter__(self):
        # Yields each row (a length-6 tensor) of the box matrix.
        yield from self.tensor


class Instances3D:
    """
    A container of per-instance fields that must all share the same length.

    Fields live in ``_fields`` and can be read/written as attributes
    (``inst.scores``) or via ``get``/``set``. Attribute names starting with an
    underscore are stored as ordinary Python attributes instead of fields.
    """

    def __init__(self, image_size: Tuple[int, int], **kwargs):
        """
        Args:
            image_size: size of the image/volume these instances belong to.
            **kwargs: initial fields, added in order through ``set``.
        """
        self._image_size = image_size
        self._fields: Dict[str, Any] = {}

        for field_name, field_value in kwargs.items():
            self.set(field_name, field_value)

    @property
    def image_size(self) -> Tuple[int, int]:
        """The ``image_size`` passed at construction."""
        return self._image_size

    def __setattr__(self, name: str, value: Any) -> None:
        # Public names become fields; private ones are plain attributes.
        if not name.startswith("_"):
            self.set(name, value)
            return
        super().__setattr__(name, value)

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails. "_fields" is excluded
        # so a partially-built instance cannot recurse looking up its own storage.
        if name != "_fields" and name in self._fields:
            return self._fields[name]
        raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))

    def set(self, name: str, value: Any) -> None:
        """
        Store ``value`` as the field ``name``.

        Raises:
            AssertionError: if other fields exist and ``len(value)`` differs from theirs.
        """
        # Warnings raised while taking len() are recorded into a discarded list
        # instead of being shown.
        with warnings.catch_warnings(record=True):
            value_len = len(value)
        if self._fields:
            assert (
                len(self) == value_len
            ), "Adding a field of length {} to a Instances of length {}".format(value_len, len(self))
        self._fields[name] = value

    def has(self, name: str) -> bool:
        """
        Returns:
            bool: True when a field called ``name`` is present.
        """
        return name in self._fields

    def clone(self) -> "Instances3D":
        """
        Copy this Instances3D.

        Fields that provide a ``clone`` method are cloned; every other field is
        shared by reference with the original.
        """
        duplicate = Instances3D(self._image_size)
        duplicate._fields.update(
            (key, val.clone() if hasattr(val, "clone") else val)
            for key, val in self._fields.items()
        )
        return duplicate

    def remove(self, name: str) -> None:
        """Delete the field ``name`` (KeyError if absent)."""
        del self._fields[name]

    def get(self, name: str) -> Any:
        """Return the field ``name`` (KeyError if absent)."""
        return self._fields[name]

    def get_fields(self) -> Dict[str, Any]:
        """Return the underlying name -> value dict itself (not a copy)."""
        return self._fields

    def to(self, *args: Any, **kwargs: Any) -> "Instances3D":
        """
        Build a new Instances3D in which each field that has a ``to`` method is
        replaced by ``field.to(*args, **kwargs)``; other fields are reused as-is.
        """
        converted = Instances3D(self._image_size)
        for key, val in self._fields.items():
            converted.set(key, val.to(*args, **kwargs) if hasattr(val, "to") else val)
        return converted

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances3D":
        """
        Index every field with ``item`` and collect the results in a new Instances3D.

        Args:
            item: an int, slice, or index/mask tensor applied to all fields.

        Raises:
            IndexError: if ``item`` is an int outside ``[-len(self), len(self))``.
        """
        if type(item) is int:
            count = len(self)
            if not -count <= item < count:
                raise IndexError("Instances index out of range")
            # A slice starting at ``item`` with step ``count`` picks exactly one
            # element while keeping a leading dimension of length 1 in each field.
            item = slice(item, None, count)

        subset = Instances3D(self._image_size)
        for key, val in self._fields.items():
            subset.set(key, val[item])
        return subset

    def __len__(self) -> int:
        """Length shared by all fields; an instance without fields has no length."""
        if not self._fields:
            raise NotImplementedError("Empty Instances does not supprt __len__")
        first_value = next(iter(self._fields.values()))
        return first_value.__len__()

    def __iter__(self):
        raise NotImplementedError("Instances3D is not iterable")

    @staticmethod
    def cat(instances_lists: List["Instances3D"]) -> "Instances3D":
        """
        Concatenate several Instances3D field by field.

        A single-element input is returned as-is (not copied). Tensors are joined
        with ``torch.cat``, lists are chained, and any other type must provide a
        ``cat`` on its class.
        """
        assert all(isinstance(i, Instances3D) for i in instances_lists)
        assert len(instances_lists) > 0

        if len(instances_lists) == 1:
            return instances_lists[0]

        first, *rest = instances_lists
        image_size = first.image_size
        if not isinstance(image_size, torch.Tensor):
            for other in rest:
                assert other.image_size == image_size

        merged = Instances3D(image_size)
        for key in first._fields:
            parts = [inst.get(key) for inst in instances_lists]
            head = parts[0]
            if isinstance(head, torch.Tensor):
                combined = torch.cat(parts, dim=0)
            elif isinstance(head, list):
                combined = list(itertools.chain.from_iterable(parts))
            elif hasattr(type(head), "cat"):
                combined = type(head).cat(parts)
            else:
                raise ValueError("Unsupported type {} for concatentation".format(type(head)))
            merged.set(key, combined)

        return merged

    def __str__(self) -> str:
        head = (
            f"{self.__class__.__name__}("
            f"num_instances={len(self)}, "
            f"image_height={self._image_size[0]}, "
            f"image_width={self._image_size[1]}, "
        )
        body = ", ".join(f"{k}: {v}" for k, v in self._fields.items())
        return f"{head}fields=[{body}])"

    __repr__ = __str__

def pairwise_intersection(boxes1: Boxes3D, boxes2: Boxes3D) -> torch.Tensor:
    """
    Compute the intersection volume of every (boxes1[i], boxes2[j]) pair.

    Args:
        boxes1, boxes2 (Boxes3D): N and M boxes, rows laid out as (x1, y1, z1, x2, y2, z2).

    Returns:
        Tensor: NxM matrix of intersection volumes (0 where boxes do not overlap).
    """
    a = boxes1.tensor
    b = boxes2.tensor

    # Broadcast (N, 1, 3) against (M, 3) -> (N, M, 3).
    lower = torch.max(a[:, None, :3], b[:, :3])  # larger of the min corners
    upper = torch.min(a[:, None, 3:], b[:, 3:])  # smaller of the max corners

    # Negative extents mean no overlap along that axis.
    extents = (upper - lower).clamp(min=0)
    return extents.prod(dim=2)

def pairwise_iou_3d(boxes1: Boxes3D, boxes2: Boxes3D) -> torch.Tensor:

    """
    Given two lists of boxes of sizes N and M, compute the IoU
    (intersection over union) between **all** N x M pairs of boxes.
    The box order must be (x1, y1, z1, x2, y2, z2).

    Args:
        boxes1, boxes2 (Boxes3D): two sets of boxes containing N and M boxes.

    Returns:
        Tensor: IoU matrix of shape (N, M). Pairs with zero intersection get 0.
    """

    vol1 = boxes1.volume()  # (N,)
    vol2 = boxes2.volume()  # (M,)

    inter = pairwise_intersection(boxes1, boxes2)  # (N, M)

    # torch.where evaluates both branches; wherever inter == 0 the division
    # result (possibly 0/0 for degenerate boxes) is discarded in favor of 0.
    iou = torch.where(
        inter > 0,
        inter / (vol1[:, None] + vol2 - inter),
        torch.zeros(1, dtype=inter.dtype, device=inter.device)
    )

    return iou

In [3]:
#| export
class ImagesList3D:
    """
    Structure that holds a list of 3D images (volumes) of possibly
    varying sizes as a single padded tensor.

    Attributes:
        tensor (Tensor): shape (N, C, D, H, W)
        image_sizes (list[tuple[int, int, int]]): original (D, H, W) for each image
    """

    def __init__(
            self,
            tensor: torch.Tensor,
            image_sizes: List[Tuple[int, int, int]]
    ):

        """
        Args:
            tensor (Tensor): shape (N, C, D, H, W)
            image_sizes (list[(D, H, W)]): original sizes (before padding)
        """

        # Stored as given; no shape validation is done here.
        self.tensor = tensor
        self.image_sizes = image_sizes

    def __len__(self) -> int:
        """Number of volumes in the batch (the length of ``image_sizes``)."""
        return len(self.image_sizes)

    def __getitem__(self, idx) -> torch.Tensor:

        """
        Access individual volume in its original size.

        Args:
            idx (int): index of the volume.

        Returns:
            Tensor: the volume with padding cropped away, indexed from ``self.tensor``.
        """

        d, h, w = self.image_sizes[idx]
        return self.tensor[idx, ..., :d, :h, :w]

    def to(self, *args: Any, **kwargs: Any) -> "ImagesList3D":
        """Return a new ImagesList3D with ``tensor.to(*args, **kwargs)``; ``image_sizes`` is shared."""
        return ImagesList3D(self.tensor.to(*args, **kwargs), self.image_sizes)

    @property
    def device(self) -> device:
        return self.tensor.device

    @staticmethod
    def from_tensors(
        tensors: List[torch.Tensor],
        size_divisibility: int = 0,
        pad_value: float = 0.0,
        padding_constraints: Optional[Dict[str, int]] = None,
    ) -> "ImagesList3D":

        """
        Pad a list of volumes to a common size and stack them into one batch.

        Args:
            tensors: list of Tensors of shape (C, D, H, W); all must share C.
            size_divisibility: if > 1, pad D/H/W up to a multiple of this value.
            pad_value: value used to fill the padded region.
            padding_constraints: optional dict. ``"cube_size"`` (> 0) pads every
                volume to a cube with that side length (before divisibility
                rounding); ``"size_divisibility"`` overrides the
                ``size_divisibility`` argument.

        Returns:
            ImagesList3D: batched tensor of shape (N, C, D_pad, H_pad, W_pad)
            together with the original (D, H, W) of each input.

        Raises:
            ValueError: if the final padded size (e.g. from a too-small
                ``cube_size``) cannot contain some input volume.
        """

        assert len(tensors) > 0
        assert isinstance(tensors, (list, tuple))

        for t in tensors:
            assert t.dim() == 4, f"Expected (C, D, H, W), got {t.shape}"
            # All volumes must share the leading (channel) dimension.
            assert t.shape[:-3] == tensors[0].shape[:-3]

        image_sizes = [(t.shape[-3], t.shape[-2], t.shape[-1]) for t in tensors]

        # Smallest (D, H, W) that contains every input volume.
        need_d = max(s[0] for s in image_sizes)
        need_h = max(s[1] for s in image_sizes)
        need_w = max(s[2] for s in image_sizes)
        max_d, max_h, max_w = need_d, need_h, need_w

        # Handle padding constraints
        if padding_constraints is not None:
            cube_size = padding_constraints.get("cube_size", 0)
            if cube_size > 0:
                max_d = max_h = max_w = cube_size

            if "size_divisibility" in padding_constraints:
                size_divisibility = padding_constraints["size_divisibility"]

        if size_divisibility > 1:
            def _ceil(x, d):
                # Round x up to the next multiple of d.
                return ((x + d - 1) // d) * d

            max_d = _ceil(max_d, size_divisibility)
            max_h = _ceil(max_h, size_divisibility)
            max_w = _ceil(max_w, size_divisibility)

        # A cube_size smaller than an input would otherwise surface as an opaque
        # shape-mismatch error from copy_ below; report it explicitly.
        if max_d < need_d or max_h < need_h or max_w < need_w:
            raise ValueError(
                f"Padded size (D, H, W) = ({max_d}, {max_h}, {max_w}) is smaller than "
                f"the largest input volume ({need_d}, {need_h}, {need_w}); "
                f"increase cube_size."
            )

        batch_shape = (
            len(tensors),
            tensors[0].shape[0],
            max_d,
            max_h,
            max_w
        )

        batched = tensors[0].new_full(batch_shape, pad_value)

        for i, vol in enumerate(tensors):
            d, h, w = vol.shape[-3:]
            batched[i, :, :d, :h, :w].copy_(vol)

        return ImagesList3D(batched.contiguous(), image_sizes)