# In [None]:
import numpy as np

# In [None]:
def bm3d(z: np.ndarray, sigma_psd: Union[np.ndarray, list, float],
         profile: Union[BM3DProfile, str] = 'np',
         stage_arg: Union[BM3DStages, np.ndarray] = BM3DStages.ALL_STAGES,
         blockmatches: tuple = (False, False))\
        -> Union[np.ndarray, Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]]:
    """
    Perform BM3D denoising on z: either hard-thresholding, Wiener filtering or both.

    :param z: Noisy image, either MxN or MxNxC where C is the channel count.
              For multichannel images, blockmatching is performed on the first channel.
    :param sigma_psd: Noise PSD, either MxN or MxNxC (different PSDs for different channels)
            or
           sigma_psd: Noise standard deviation, either float, or length C list of floats
    :param profile: Settings for BM3D: BM3DProfile object or a string
                    ('np', 'refilter', 'vn', 'vn_old', 'high', 'deb'). Default 'np'.
    :param stage_arg: Determines whether to perform hard-thresholding or Wiener filtering:
                    either BM3DStages.HARD_THRESHOLDING, BM3DStages.ALL_STAGES or an estimate
                    of the noise-free image.
                    - BM3DStages.ALL_STAGES: Perform both.
                    - BM3DStages.HARD_THRESHOLDING: Perform hard-thresholding only.
                    - ndarray, size of z: Perform Wiener filtering with stage_arg as pilot.
    :param blockmatches: Tuple (HT, Wiener), with each value either:
                        - False : Do not save blockmatches for that phase
                        - True : Save blockmatches for that phase
                        - Pre-computed block-matching array returned by a previous call with True
                        For example:
                            y_est, matches = bm3d(z, sigma_psd, profile, blockmatches=(True, True))
                            y_est2 = bm3d(z2, sigma_psd, profile, blockmatches=matches)
    :return:
        - denoised image (MxN for single-channel input, otherwise MxNxC):
          if blockmatches == (False, False)
        - denoised image, (ht_blocks, wie_blocks): if either element of blockmatches is True.
          A phase whose blockmatches were not requested is returned as the placeholder
          np.zeros(1, dtype=np.intc).
    :raises ValueError: if z is not 2-D or 3-D, if the estimate passed in stage_arg differs
                        in shape from z, if Wiener-only filtering is requested without an
                        estimate, if the image is smaller than the block size, or if no
                        stage is selected.
    """

    # Resolve the profile name (or BM3DProfile object) into the settings object
    # that the rest of this function reads as `pro`.
    # NOTE(review): relies on the module-level `_select_profile` helper (not visible
    # in this chunk) - confirm it is defined or imported in this file.
    pro = _select_profile(profile)

    # Ensure z is a 3-D numpy array (M x N x C).
    z = np.array(z)
    if z.ndim not in (2, 3):
        raise ValueError("z must be either a 2D or a 3D image!")
    if z.ndim == 2:
        z = np.atleast_3d(z)

    # If the profile defines a maximum required pad, use it; otherwise pad by
    # half of the image size in each spatial dimension.
    if pro.max_pad_size is None:
        pad_size = (int(np.ceil(z.shape[0] / 2)), int(np.ceil(z.shape[1] / 2)))
    else:
        pad_size = pro.max_pad_size

    # z is always 3-D at this point: add a zero pad for the channel dimension.
    if len(pad_size) == 2:
        pad_size = (pad_size[0], pad_size[1], 0)

    y_hat = None
    ht_blocks = None
    wie_blocks = None

    # A numpy array passed as stage_arg is presumed to be a hard-thresholding
    # estimate, used as the pilot for Wiener filtering.
    if isinstance(stage_arg, np.ndarray):
        y_hat = np.atleast_3d(stage_arg)
        if y_hat.shape != z.shape:
            raise ValueError("Estimate passed in stage_arg must be equal size to z!")
        stage_arg = BM3DStages.WIENER_FILTERING
    elif stage_arg == BM3DStages.WIENER_FILTERING:
        raise ValueError("If you wish to only perform Wiener filtering, you need to pass an estimate as stage_arg!")

    if min(z.shape[0], z.shape[1]) < max(pro.bs_ht, pro.bs_wiener):
        raise ValueError("Image cannot be smaller than block size!")

    # Stages are combined as bit flags on stage_arg.value.
    # Hard thresholding runs if its bit is set, whether or not Wiener follows.
    stage_ht = (stage_arg.value & BM3DStages.HARD_THRESHOLDING.value) != 0
    # Wiener filtering runs if its bit is set.
    stage_wie = (stage_arg.value & BM3DStages.WIENER_FILTERING.value) != 0

    # Fail fast, before any PSD processing or filtering work is done.
    if not stage_ht and not stage_wie:
        raise ValueError("No operation was selected!")

    channel_count = z.shape[2]

    # sigma_psd is either a noise std (scalar or one per channel) or a full PSD.
    sigma_psd = np.array(sigma_psd)
    squeezed_ndim = np.squeeze(sigma_psd).ndim
    single_dim_psd = squeezed_ndim <= 1
    if squeezed_ndim == 1:
        # A list of per-channel stds becomes shape (1, 1, C).
        sigma_psd = np.atleast_3d(np.ravel(sigma_psd)).transpose(0, 2, 1)
    else:
        sigma_psd = np.atleast_3d(sigma_psd)

    # Handle blockmatching inputs. Booleans become a 1-element int32 flag array.
    # Anything else (e.g. pre-computed matches) is passed through unchanged.
    blockmatches_ht, blockmatches_wie = blockmatches
    if isinstance(blockmatches_ht, bool):
        blockmatches_ht = np.array([blockmatches_ht], dtype=np.int32)
    if isinstance(blockmatches_wie, bool):
        blockmatches_wie = np.array([blockmatches_wie], dtype=np.int32)

    sigma_psd2, psd_blur, psd_k = _process_psd(sigma_psd, z, single_dim_psd, pad_size, pro)

    # Step 1. Produce the basic estimate by HT filtering
    if stage_ht:

        # Get used transforms and aggregation windows.
        t_forward, t_inverse, hadper_trans_single_den,\
            inverse_hadper_trans_single_den, wwin2d = _get_transforms(pro, True)

        # Call the actual hard-thresholding step with the acquired parameters
        y_hat, ht_blocks = bm3d_step(BM3DStages.HARD_THRESHOLDING, z, psd_blur, single_dim_psd,
                                     pro, t_forward, t_inverse.T, hadper_trans_single_den,
                                     inverse_hadper_trans_single_den, wwin2d, channel_count, blockmatches_ht)
        if pro.print_info:
            print('Hard-thresholding stage completed')

        # Residual denoising, HT
        if pro.denoise_residual:
            remains, remains_psd = get_filtered_residual(z, y_hat, sigma_psd2, pad_size, pro.residual_thr)
            remains_psd = _process_psd_for_nf(remains_psd, psd_k, pro)

            # Re-filter only if the smallest per-channel peak of the residual PSD
            # exceeds 1e-5. This assumes remains_psd is M x N x C - TODO confirm.
            if np.min(np.max(np.max(remains_psd, axis=0), axis=0)) > 1e-5:
                y_hat, ht_blocks = bm3d_step(BM3DStages.HARD_THRESHOLDING, y_hat + remains, remains_psd, False,
                                             pro, t_forward, t_inverse.T, hadper_trans_single_den,
                                             inverse_hadper_trans_single_den, wwin2d, channel_count,
                                             blockmatches_ht, refiltering=True)

    # Step 2. Produce the final estimate by Wiener filtering, using the
    # hard-thresholding estimate (or the estimate passed in stage_arg) as pilot.
    if stage_wie:

        # Get used transforms and aggregation windows.
        t_forward, t_inverse, hadper_trans_single_den,\
            inverse_hadper_trans_single_den, wwin2d = _get_transforms(pro, False)

        # Multiply PSDs by the profile's mu2 factors, laid out along the third axis.
        # For std-style (single-dim) noise input, the square roots of the factors are used.
        mu_list = np.ravel(pro.mu2).reshape([1, 1, np.size(pro.mu2)])
        if single_dim_psd:
            mu_list = np.sqrt(mu_list)
        psd_blur_mult = psd_blur * mu_list

        # Wiener filtering
        y_hat, wie_blocks = bm3d_step(BM3DStages.WIENER_FILTERING, z, psd_blur_mult, single_dim_psd,
                                      pro, t_forward, t_inverse.T, hadper_trans_single_den,
                                      inverse_hadper_trans_single_den, wwin2d, channel_count,
                                      blockmatches_wie, y_hat=y_hat)

        # Residual denoising, Wiener
        if pro.denoise_residual:
            remains, remains_psd = get_filtered_residual(z, y_hat, sigma_psd2, pad_size, pro.residual_thr)
            remains_psd = _process_psd_for_nf(remains_psd, psd_k, pro)

            # Same re-filtering threshold as in the HT stage.
            if np.min(np.max(np.max(remains_psd, axis=0), axis=0)) > 1e-5:

                psd_blur_mult = remains_psd * np.ravel(pro.mu2_re).reshape([1, 1, np.size(pro.mu2_re)])
                y_hat, wie_blocks = bm3d_step(BM3DStages.WIENER_FILTERING, y_hat + remains, psd_blur_mult, False,
                                              pro, t_forward, t_inverse.T, hadper_trans_single_den,
                                              inverse_hadper_trans_single_den, wwin2d, channel_count,
                                              blockmatches_wie, refiltering=True, y_hat=y_hat)

        if pro.print_info:
            print('Wiener-filtering stage completed')

    # Remove the singleton channel dimension for single-channel output.
    if channel_count == 1:
        y_hat = y_hat[:, :, 0]

    # Blockmatches are returned if either phase requested them (flag value 1).
    incl_blockmatches = (blockmatches_ht is not None and blockmatches_ht[0] == 1) or \
                        (blockmatches_wie is not None and blockmatches_wie[0] == 1)

    # Phases whose matches were not requested (or not produced) get a placeholder.
    if blockmatches_ht is None or blockmatches_ht[0] != 1 or ht_blocks is None:
        ht_blocks = np.zeros(1, dtype=np.intc)

    if blockmatches_wie is None or blockmatches_wie[0] != 1 or wie_blocks is None:
        wie_blocks = np.zeros(1, dtype=np.intc)

    if incl_blockmatches:
        return y_hat, (ht_blocks, wie_blocks)

    return y_hat