In [None]:
def image_hist2d(image: torch.Tensor, min: float = 0., max: float = 255.,
                 n_bins: int = 256, bandwidth: float = -1.,
                 centers: torch.Tensor = torch.tensor([]), return_pdf: bool = False):
    """Estimate the per-channel histogram of the input image(s).

    The estimate uses triangular kernel density estimation: a pixel of value
    ``x`` contributes ``1 - |x - c| / bandwidth`` to the bin centered at ``c``
    when ``|x - c| <= bandwidth`` and nothing otherwise.

    Args:
        image: Input tensor to compute the histogram of, with shape
        :math:`(H, W)`, :math:`(C, H, W)` or :math:`(B, C, H, W)`.
        min: Lower end of the interval (inclusive). Must be a ``float``.
        max: Upper end of the interval (inclusive). Must be a ``float``.
        Not used for the bin placement when :attr:`centers` is specified.
        n_bins: The number of histogram bins. Not used for the bin placement
        when :attr:`centers` is specified.
        bandwidth: Smoothing factor. If not specified or equal to -1,
        bandwidth = (max - min) / n_bins (this default is derived from
        :attr:`min`, :attr:`max` and :attr:`n_bins` even when
        :attr:`centers` is given).
        centers: Centers of the bins with shape :math:`(n_bins,)`.
        If not specified, ``None`` or empty, it is calculated as centers of
        equal width bins of [min, max] range.
        return_pdf: If True, also return probability densities for
        each bin.

    Returns:
        A tuple ``(hist, pdf)``.

        ``hist`` is always a 3-dimensional tensor: :math:`(B, C, bins)` for a
        :math:`(B, C, H, W)` input, :math:`(1, C, bins)` for a
        :math:`(C, H, W)` input and :math:`(1, 1, bins)` for a
        :math:`(H, W)` input.

        ``pdf`` has the same shape as ``hist``. It holds the histogram
        normalized to sum to one over the bins of each channel if
        :attr:`return_pdf` is ``True``, and zeros otherwise.

    Raises:
        TypeError: If an argument has an unexpected type.
        ValueError: If :attr:`centers` is non-empty and not 1-dimensional, or
        if :attr:`image` is not 2-, 3- or 4-dimensional.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input image type is not a torch.Tensor. Got {type(image)}.")

    # ``None`` used to slip through the type check and then crash on
    # ``centers.numel()``; treat it the same as "not specified".
    if centers is None:
        centers = torch.tensor([])
    elif not isinstance(centers, torch.Tensor):
        raise TypeError(f"Bins' centers type is not a torch.Tensor. Got {type(centers)}.")

    if centers.numel() > 0 and centers.dim() != 1:
        raise ValueError(f"Bins' centers must be a torch.Tensor "
                         f"of the shape (n_bins,). Got {centers.shape}.")

    if not isinstance(min, float):
        raise TypeError(f'Type of lower end of the range is not a float. Got {type(min)}.')

    if not isinstance(max, float):
        raise TypeError(f"Type of upper end of the range is not a float. Got {type(max)}.")

    if not isinstance(n_bins, int):
        raise TypeError(f"Type of number of bins is not an int. Got {type(n_bins)}.")

    if bandwidth != -1 and not isinstance(bandwidth, float):
        raise TypeError(f"Bandwidth type is not a float. Got {type(bandwidth)}.")

    if not isinstance(return_pdf, bool):
        raise TypeError(f"Return_pdf type is not a bool. Got {type(return_pdf)}.")

    if image.dim() not in (2, 3, 4):
        raise ValueError(f"Input values must be a of the shape BxCxHxW, "
                         f"CxHxW or HxW. Got {image.shape}.")

    device = image.device

    if bandwidth == -1.:
        bandwidth = (max - min) / n_bins
    if centers.numel() == 0:
        centers = min + bandwidth * (torch.arange(n_bins, device=device).float() + 0.5)

    # Put the bins on a new leading axis so they broadcast against every
    # pixel: (bins, 1, 1, 1, 1) against (1, [B, [C,]] H, W).
    centers = centers.reshape(-1, 1, 1, 1, 1)
    u = torch.abs(image.unsqueeze(0) - centers) / bandwidth
    mask = (u <= 1).float()

    # Sum the kernel weights over the spatial axes, then move the bin axis
    # last: (bins, B, C) -> (B, C, bins).
    hist = torch.sum((1 - u) * mask, dim=(-2, -1)).permute(1, 2, 0)

    if return_pdf:
        # keepdim=True keeps the total aligned with the bin axis; the epsilon
        # guards against channels in which no pixel falls into any bin.
        normalization = torch.sum(hist, dim=-1, keepdim=True) + 1e-10
        pdf = hist / normalization
        return hist, pdf
    return hist, torch.zeros_like(hist)