# **Overview**
This notebook benchmarks MONAI's implementation of global mutual information against ANTsPyx's implementation.

# **Global Mutual Information**
Mutual information is an entropy-based measure of image alignment derived from probabilistic measures of image intensity
values. Because a large number of image samples are used to estimate image statistics, the effects of image noise on the
metric are attenuated. Mutual information is also robust under varying amounts of image overlap as the test image moves
with respect to the reference. [1]

Formally, the mutual information between two images `A` and `B` is defined as follows:

<img src=https://latex.codecogs.com/svg.image?I(a%2Cb)%26space%3B%3D%26space%3B%5Csum_%7Ba%2Cb%7D%26space%3Bp(a%2Cb)%26space%3B%5Clog(%5Cfrac%7Bp(a%2Cb)%7D%7Bp(a)p(b)%7D)>

where `a` and `b` refer to the intensity bin centers of `A` and `B`, respectively.
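As a reference point for the definition above, the NumPy sketch below estimates `I(A, B)` from a plain (hard-binned) joint histogram instead of Parzen windows. `histogram_mutual_information` is a hypothetical helper for illustration, not part of MONAI or ANTsPy.

```python
import numpy as np


def histogram_mutual_information(img_a, img_b, num_bins=32):
    """Mutual information (in nats) estimated from a hard-binned joint histogram."""
    # joint histogram of the intensities at corresponding locations
    joint, _, _ = np.histogram2d(img_a.ravel(), img_b.ravel(), bins=num_bins)
    p_ab = joint / joint.sum()               # p(a, b)
    p_a = p_ab.sum(axis=1, keepdims=True)    # p(a): marginalise over b
    p_b = p_ab.sum(axis=0, keepdims=True)    # p(b): marginalise over a
    nz = p_ab > 0                            # terms with p(a, b) = 0 contribute 0
    return float(np.sum(p_ab[nz] * np.log(p_ab[nz] / (p_a * p_b)[nz])))


rng = np.random.default_rng(0)
a = rng.random((64, 64))
print(histogram_mutual_information(a, a))                     # large: B is a copy of A
print(histogram_mutual_information(a, rng.random((64, 64))))  # small: unrelated images
```

For an image compared with itself, the joint histogram is diagonal and the estimate reduces to the entropy of the intensity histogram, which is the largest value MI can take for that image.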

We use Parzen windowing in our implementation: given a set of `n` samples from image `A`, each sample `x` contributes to
`p(a)` through a weighting function `W(x, a)` of its intensity and the bin center `a`:

<img src=https://latex.codecogs.com/svg.image?p(a)%3D%26space%3B%5Cfrac%7B1%7D%7Bn%7D%26space%3B%5Csum_%7Bx%26space%3B%5Cin%26space%3BA%7D%26space%3BW(x%2C%26space%3Ba)>

Similarly, for image `B`:

<img src=https://latex.codecogs.com/svg.image?p(b)%3D%26space%3B%5Cfrac%7B1%7D%7Bn%7D%26space%3B%5Csum_%7By%26space%3B%5Cin%26space%3BB%7D%26space%3BW(y%2C%26space%3Bb)>

To compute the joint distribution, we treat each sample as a pair of intensities at corresponding locations in the two images:

<img src=https://latex.codecogs.com/svg.image?p(a%2Cb)%26space%3B%3D%26space%3B%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7B(x%2Cy)%5Cin(A%2CB)%7D%26space%3BW(x%2Ca)W(y%2Cb)%26space%3B>
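The three Parzen-window formulas above can be sketched in NumPy as follows. This is an illustration only, not MONAI's code: it assumes a cubic B-spline for `W` on both images, evenly spaced bin centers between each image's minimum and maximum intensity, and per-sample normalization of the weights so that every sample contributes a total weight of 1. All function names are hypothetical.

```python
import numpy as np


def cubic_bspline(u):
    """Cubic B-spline beta_3(u), non-zero only for |u| < 2."""
    u = np.abs(u)
    return np.where(u < 1, (4 - 6 * u**2 + 3 * u**3) / 6,
                    np.where(u < 2, (2 - u) ** 3 / 6, 0.0))


def parzen_weights(samples, bin_centers):
    """W(x, a) for every sample x (rows) and bin center a (columns).

    Assumption: each row is renormalised to sum to 1.
    """
    width = bin_centers[1] - bin_centers[0]
    k = cubic_bspline((samples[:, None] - bin_centers[None, :]) / width)
    return k / k.sum(axis=1, keepdims=True)


def parzen_distributions(img_a, img_b, num_bins=32):
    x, y = img_a.ravel(), img_b.ravel()   # corresponding sample pairs (x, y)
    n = x.size
    w_a = parzen_weights(x, np.linspace(x.min(), x.max(), num_bins))  # (n, num_bins)
    w_b = parzen_weights(y, np.linspace(y.min(), y.max(), num_bins))
    p_a = w_a.sum(axis=0) / n             # p(a)    = 1/n sum_x W(x, a)
    p_b = w_b.sum(axis=0) / n             # p(b)    = 1/n sum_y W(y, b)
    p_ab = w_a.T @ w_b / n                # p(a, b) = 1/n sum W(x, a) W(y, b)
    return p_a, p_b, p_ab


def mutual_information(p_a, p_b, p_ab):
    outer = np.outer(p_a, p_b)            # p(a) p(b)
    nz = p_ab > 0
    return float(np.sum(p_ab[nz] * np.log(p_ab[nz] / outer[nz])))


rng = np.random.default_rng(0)
a = rng.random((64, 64))
b = a + 0.1 * rng.random((64, 64))        # b is a noisy copy of a
p_a, p_b, p_ab = parzen_distributions(a, b)
print(mutual_information(p_a, p_b, p_ab))
```

Because each sample's weights sum to 1, summing `p(a, b)` over `b` recovers `p(a)` exactly (and likewise for `p(b)`), so the joint and marginal estimates stay consistent with each other.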


Two weighting functions, ``"gaussian"`` and ``"b-spline"``, are provided.
Here, we compare our ``"b-spline"`` method with the validated [ANTsPy](https://antspy.readthedocs.io/en/latest/)
library.

>[1] "PET-CT Image Registration in the Chest Using Free-form Deformations",
D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank,
IEEE Transactions on Medical Imaging, vol. 22, no. 1,
pp. 120-128, January 2003.

# **Setup environment**

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"
!python -c "import ants" || pip install -q antspyx==0.2.9
!python -c "import plotly" || pip install -q plotly==5.3

In [None]:
import ants
import os
import tempfile
import torch
import plotly.graph_objects as go
import numpy as np
from monai import transforms
from monai.apps.utils import download_url
from monai.losses import GlobalMutualInformationLoss

In [None]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from monai.config import print_config

print_config()

# **Download data**

In [None]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(f"root dir is: {root_dir}")
file_url = "https://drive.google.com/uc?id=17tsDLvG_GZm7a4fCVMCv-KyDx0hqq1ji"
file_path = f"{root_dir}/Prostate_T2W_AX_1.nii"
download_url(file_url, file_path)

# **Comparison**
Both ANTsPy's implementation and ours follow [1]: a third-order B-spline kernel is used for the intensity PDF of the predicted image,
and a zero-order (boxcar) B-spline kernel is used for the intensity PDF of the target image.
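The two kernels named above are the centered uniform B-splines of order 0 and 3. The sketch below (hypothetical helper `bspline_kernel`) evaluates them on bin coordinates; how intensities are mapped to bin coordinates, and any padding bins at the edges of the intensity range, are left out and may differ between MONAI and ANTsPy.

```python
import numpy as np


def bspline_kernel(u, order):
    """Centered uniform B-spline of order 0 (boxcar) or 3 (cubic)."""
    u = np.asarray(u, dtype=float)
    if order == 0:
        # boxcar on the half-open interval [-0.5, 0.5): each point lands in exactly one bin
        return np.where((u >= -0.5) & (u < 0.5), 1.0, 0.0)
    if order == 3:
        a = np.abs(u)
        return np.where(a < 1, (4 - 6 * a**2 + 3 * a**3) / 6,
                        np.where(a < 2, (2 - a) ** 3 / 6, 0.0))
    raise ValueError("only orders 0 and 3 are sketched here")


# weights that one sample at bin coordinate 10.3 gives to each of 32 bins
bins = np.arange(32)
weights_cubic = bspline_kernel(10.3 - bins, order=3)  # spread over neighbouring bins
weights_box = bspline_kernel(10.3 - bins, order=0)    # a single bin
print(np.nonzero(weights_cubic)[0], np.nonzero(weights_box)[0])
```

With the cubic kernel, a sample at bin coordinate 10.3 spreads its weight over bins 9 to 12, whereas the boxcar kernel assigns it entirely to bin 10. Both kernels sum to 1 over all integer shifts, so each sample contributes a total weight of 1 either way.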

For benchmarking, we set the number of bins to 32, the same as the
[ANTsPy implementation](https://github.com/ANTsX/ANTsPy/blob/master/ants/lib/LOCAL_antsImageMutualInformation.cxx).

We take a 3D MRI of the lower pelvis as `a1`, transform it to obtain `a2`, and report the
global mutual information between `a1` and `a2` computed by ANTsPy and by our implementation.
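The expected trend can be sanity-checked on synthetic data before running the real comparison: a volume shares the most information with an exact copy of itself, and MI should drop as the copy is shifted. The sketch below uses a smoothed random volume as a hypothetical stand-in for the MRI, hard binning instead of Parzen windows, and `np.roll` as the shift (it wraps around at the border, unlike `Affined`). Note that the plots later in this notebook use a y-axis range of [-2, 0]: both libraries appear to report MI with a negative sign, so values closer to 0 indicate worse alignment.

```python
import numpy as np


def histogram_mutual_information(a, b, num_bins=32):
    joint, _, _ = np.histogram2d(a.ravel(), b.ravel(), bins=num_bins)
    p_ab = joint / joint.sum()
    p_a = p_ab.sum(axis=1, keepdims=True)
    p_b = p_ab.sum(axis=0, keepdims=True)
    nz = p_ab > 0
    return float(np.sum(p_ab[nz] * np.log(p_ab[nz] / (p_a * p_b)[nz])))


# hypothetical stand-in for "a1": random noise smoothed by repeated 3-voxel box filters
rng = np.random.default_rng(0)
vol = rng.standard_normal((32, 32, 32))
for _ in range(3):
    for axis in range(3):
        vol = (np.roll(vol, 1, axis) + vol + np.roll(vol, -1, axis)) / 3

# "a2" = "a1" shifted along the first axis by s voxels
shifts = list(range(0, 10, 2))
mi_by_shift = [histogram_mutual_information(vol, np.roll(vol, s, axis=0)) for s in shifts]
for s, mi in zip(shifts, mi_by_shift):
    print(f"shift={s}: MI={mi:.3f} nats")
```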

First, we define a few helper functions for the comparison.

In [None]:
def transformation(
        translate_params=(0., 0., 0.),
        rotate_params=(0., 0., 0.),
):
    """
    Read and transform Prostate_T2W_AX_1.nii
    Args:
        translate_params: a tuple of 3 floats, translation in pixels/voxels relative to the center of the input image.
                Defaults to no translation.
        rotate_params: a tuple of 3 floats, rotation angles in radians about each axis.
                Defaults to no rotation.
    Returns:
        tensor of shape HWD
    """
    transform_list = [
        transforms.LoadImaged(keys="img"),
        # Affined expects channel-first input, so add a channel dimension: (1, H, W, D)
        transforms.EnsureChannelFirstd(keys="img"),
        transforms.Affined(
            keys="img",
            translate_params=translate_params,
            rotate_params=rotate_params,
            device=None,
        ),
        transforms.NormalizeIntensityd(keys=["img"]),
    ]
    pipeline = transforms.Compose(transform_list)
    # drop the channel dimension again to return an HWD volume
    return pipeline({"img": file_path})["img"][0]

In [None]:
def get_result(a1, a2):
    """
    Calculate mutual information with both the ANTsPyx and MONAI implementations
    Args:
        a1: tensor of shape HWD
        a2: tensor of shape HWD
    Returns:
        a tuple (ANTsPyx result, MONAI result) of floats
    """
    a1_np = a1.detach().cpu().numpy()
    a2_np = a2.detach().cpu().numpy()
    antspyx_result = ants.image_mutual_information(
        ants.from_numpy(a1_np),
        ants.from_numpy(a2_np),
    )
    monai_result = GlobalMutualInformationLoss(
        kernel_type="b-spline",
        num_bins=32,
        sigma_ratio=0.015,
    )(
        # add batch and channel dimensions: (1, 1, H, W, D)
        torch.from_numpy(a1_np)[None, None],
        torch.from_numpy(a2_np)[None, None],
    ).item()
    return antspyx_result, monai_result

In [None]:
def plot(x, results, xaxis_title):
    """
    Plot diagram to compare ANTsPyx and MONAI result
    Args:
        x: list, x_axis values
        results: list of list
        xaxis_title: str
    """
    data = [
        go.Scatter(
            x=x,
            y=y,
            name=n,
            mode="lines+markers",
            line={'color': color, 'width': 1},
        )
        for y, n, color in zip(results, ['ANTsPy', 'MONAI'], ['coral', 'cornflowerblue'])
    ]
    fig = go.Figure(data=data)
    fig.update_layout(
        xaxis_title=xaxis_title,
        yaxis_title='MutualInformation',
        yaxis_range=[-2.0, 0.0]
    )
    fig.show()

In [None]:
def compare_antspyx_monai(transform_params_list, transform_name):
    """
    Compare ANTsPyx and MONAI mutual information over a series of transformations
    Args:
        transform_params_list: a list of (translate_params, rotate_params) tuples
        transform_name: str, used as the x-axis title
    """
    antspyx_result = []
    monai_result = []
    # a1 is the original image without translation and rotation
    a1 = transformation((0., 0., 0.))

    for translate_params, rotate_params in transform_params_list:
        # translate/rotate the image to get a2
        a2 = transformation(
            translate_params=translate_params,
            rotate_params=rotate_params,
        )
        a_r, m_r = get_result(a1, a2)
        antspyx_result.append(a_r)
        monai_result.append(m_r)

    # x-axis: Euclidean norm of all transformation parameters (translation and rotation);
    # in each experiment below only one of the two is non-zero
    x = [np.linalg.norm(np.array(params)) for params in transform_params_list]
    # sort results by the transformation norm
    order = np.argsort(x)
    plot(
        x=[x[i] for i in order],
        results=[[antspyx_result[i] for i in order], [monai_result[i] for i in order]],
        xaxis_title=transform_name,
    )

The following image shows the 3D MRI after transformation with different translation parameters:

![a](https://i.ibb.co/6X03szZ/translation-vis.png)

**Translation**

First, we incrementally increase the translation along all three axes (x, y, z) in steps of (1.0, 1.0, 1.0) voxels.
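In `compare_antspyx_monai`, the x-axis is the Euclidean norm of all transformation parameters rather than the step index. For this sweep, step `i` therefore plots at `i * sqrt(3)` voxels, which the quick check below reproduces using the same norm computation.

```python
import numpy as np

transform_params_list = [((i, i, i), (0., 0., 0.)) for i in range(10)]
# same computation as in compare_antspyx_monai: norm of the (translation, rotation) parameters
x = [float(np.linalg.norm(np.array(params))) for params in transform_params_list]
print([round(v, 2) for v in x])
```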

In [None]:
transform_params_list = [((i, i, i), (0., 0., 0.)) for i in range(10)]
compare_antspyx_monai(transform_params_list, "xyz_translation")

Next, we translate along one axis at a time by randomly sampled amounts in the range [0, 10) voxels.
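`np.random.rand` is not seeded in the cells below, so the sampled offsets (and therefore the plots) change on every run. If repeatability matters, a seeded NumPy `Generator` can draw the same parameters each time. In this sketch the seed value 42 is an arbitrary choice and `x_translation_params` is a hypothetical name.

```python
import numpy as np

# seeded generator: the same seed always yields the same sequence of offsets
rng = np.random.default_rng(seed=42)
# uniform in [0, 10) voxels along x, no rotation (same range as np.random.rand() * 10)
x_translation_params = [((rng.uniform(0.0, 10.0), 0., 0.), (0., 0., 0.)) for _ in range(10)]
```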

In [None]:
transform_params_list = [((np.random.rand() * 10, 0., 0.), (0., 0., 0.)) for _ in range(10)]
compare_antspyx_monai(transform_params_list, "x_translation")

In [None]:
transform_params_list = [((0., np.random.rand() * 10, 0.), (0., 0., 0.)) for _ in range(10)]
compare_antspyx_monai(transform_params_list, "y_translation")

In [None]:
transform_params_list = [((0., 0., np.random.rand() * 10), (0., 0., 0.)) for _ in range(10)]
compare_antspyx_monai(transform_params_list, "z_translation")

**Rotation**

We also incrementally increase the rotation about all three axes (x, y, z), in steps of π/100 radians (1.8°) per axis.
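For reference, the rotation step `np.pi / 100` corresponds to 1.8° per axis, so the largest rotation in this sweep (step 9) is 16.2° about each axis. The x-axis value plotted by `compare_antspyx_monai` is again the norm of all parameters, i.e. `sqrt(3) * i * pi / 100` radians, as computed below.

```python
import numpy as np

# rotation angle (radians) about each axis at step i, as in the cell below
angles_rad = [np.pi / 100 * i for i in range(10)]
angles_deg = np.degrees(angles_rad)  # 1.8 degrees per step
# x-axis value used by compare_antspyx_monai: norm of all (translation, rotation) parameters
x_rotation = [float(np.linalg.norm(np.array(((0., 0., 0.), (r, r, r))))) for r in angles_rad]
print(np.round(angles_deg, 1))
```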

In [None]:
transform_params_list = [((0., 0., 0.), (np.pi / 100 * i, np.pi / 100 * i, np.pi / 100 * i)) for i in range(10)]
compare_antspyx_monai(transform_params_list, "rotation")