# Image registration metrics for the unstain2stain images:
- Mutual Information (MI)
- Target Registration Error (TRE, SimpleITK)

In [16]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import cv2
from matplotlib import pyplot as plt
from glob import glob
from skimage.registration import optical_flow_tvl1
from time import time
from skimage.transform import warp
import matplotlib

### Mutual Information (MI):

In [22]:
# Adapted from skin_2D_workflow\image_registration_evaluation to calculate mutual information.
def _load_grayscale(image):
    """Return a 2-D grayscale array for `image` (a file path or an already-loaded array).

    Paths are read with ``cv2.imread`` (BGR). 2-D arrays are returned unchanged; colour
    arrays are treated as BGR, matching what ``cv2.imread`` produces.
    """
    if isinstance(image, np.ndarray):
        return image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    bgr = cv2.imread(os.fspath(image))
    if bgr is None:  # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError(f"Could not read image: {image}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)


def calculate_mutual_information(img_path_1, img_path_2, bins=20):
    """
    Mutual information (in nats) between the grayscale intensities of two images.

    MI is symmetric, so the argument order does not matter: MI(A, B) == MI(B, A).

    Args:
        img_path_1, img_path_2: Image file paths (read with ``cv2.imread``), or
            already-loaded arrays. 2-D arrays are used as grayscale directly; colour
            arrays are treated as BGR and converted to grayscale.
        bins: Number of bins per axis of the joint intensity histogram (default 20).

    Returns:
        float: MI estimated from the normalised joint histogram; 0.0 for empty images.

    Raises:
        FileNotFoundError: If a path cannot be read as an image.
        ValueError: If the two images do not have the same shape (pixels are paired 1:1).
    """
    img1_g = _load_grayscale(img_path_1)
    img2_g = _load_grayscale(img_path_2)
    if img1_g.shape != img2_g.shape:
        raise ValueError(
            f"Images must have the same shape to pair pixels, got {img1_g.shape} and {img2_g.shape}"
        )

    hist2d, _, _ = np.histogram2d(img1_g.ravel(), img2_g.ravel(), bins=bins)
    total = hist2d.sum()
    if total == 0:  # empty images: no joint samples, nothing to measure
        return 0.0
    pxy = hist2d / total                # joint probability p(x, y)
    px = pxy.sum(axis=1)                # marginal p(x), summed over y
    py = pxy.sum(axis=0)                # marginal p(y), summed over x
    px_py = px[:, None] * py[None, :]   # p(x) * p(y) via broadcasting
    nonzero = pxy > 0                   # 0 * log(0) terms contribute nothing; skip them
    return float(np.sum(pxy[nonzero] * np.log(pxy[nonzero] / px_py[nonzero])))

In [23]:
# glob() returns paths in arbitrary, filesystem-dependent order, but the MI loop pairs the
# two lists by index, so both lists are sorted. NOTE(review): pairing by sorted order assumes
# corresponding unstained/stained files sort identically (e.g. share names) - confirm.
img_path_1 = sorted(glob(os.path.join(r"\\shelter\Kyu\unstain2stain\biomax_images\registrated_images","**","**.png"),recursive=True)) #unstained, 174 images
img_path_2 = sorted(glob(os.path.join(r"\\shelter\Kyu\unstain2stain\biomax_images\stained\padded_images","**","**.png"),recursive=True)) #stained, 174 images

In [24]:
# MI for each (unstained, stained) image pair. Pairing is by list position, so both lists
# must be the same length (and in corresponding order).
if len(img_path_1) != len(img_path_2):
    raise ValueError(
        f"Unequal numbers of images: {len(img_path_1)} unstained vs {len(img_path_2)} stained"
    )
mi_ra = [calculate_mutual_information(img1, img2) for img1, img2 in zip(img_path_1, img_path_2)]
print(mi_ra)

[0.06960824587047423, 0.11743108445857318, 0.14049127955630492, 0.04516067106990143, 0.07190700019514393, 0.1026032203130491, 0.03944753233642199, 0.12487647436827311, 0.08166988417207607, 0.07407601452823742, 0.08675877265733534, 0.026672960834480484, 0.03882193563466624, 0.08350645432596476, 0.047932620308268954, 0.1557985326446766, 0.10315432003633795, 0.037578114212290056, 0.06191071062390437, 0.04281327678146191, 0.05740118990975115, 0.11081469698863479, 0.12363178346856792, 0.1531894553591467, 0.06883738548188653, 0.039376671331586136, 0.1283598570718828, 0.08570480086665128, 0.06703886482233695, 0.06059291555478167, 0.07624198946936489, 0.03936230216049576, 0.03796931610627313, 0.08174288018862663, 0.062358578765578675, 0.10845610990246536, 0.047642168106918764, 0.15492445500665886, 0.049275052229929325, 0.009416651217994598, 0.08069333541634434, 0.07923558485749368, 0.01876567784599072, 0.046131298466690014, 0.05467380562964635, 0.05526217760977068, 0.08458265097445196, 0.04745

# Because the stained and unstained images are different modalities, mutual information does not perform well here, as shown by the low scores above for all 174 images.

### Now try Target Registration Error (TRE):
### Paper using TRE and optical flow: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9099354/

In [2]:
def _load_rgb(path):
    """Read an image file as an (H, W, 3) uint8 RGB array; any alpha channel is dropped."""
    with Image.open(path) as img:  # close the file handle once the pixels are copied
        return np.array(img.convert("RGB"))


def _to_uint8(channel):
    """Convert a float image in [0, 1] (what skimage ``warp`` returns for uint8 input) to uint8.

    Rounds instead of truncating, so e.g. 254.9999 becomes 255 rather than 254.
    """
    return np.clip(np.rint(channel * 255), 0, 255).astype(np.uint8)


def registrate_two_images(reference_image_path, moving_image_path, save_path):
    """
    Register each moving image onto its reference image with TV-L1 optical flow and save it.

    The ``.png`` file names of each folder are sorted and paired by position, so the i-th
    reference file (in sorted order) is registered with the i-th moving file. Make sure
    corresponding images sort into the same order in both folders!

    Args:
        reference_image_path (str): Folder with the fixed/reference ``.png`` images.
        moving_image_path (str): Folder with the moving ``.png`` images to warp.
        save_path (str): Existing folder where each warped image is written under the
            moving image's file name.

    Returns:
        None. If the folders hold different numbers of ``.png`` files, a message is printed
        and nothing is registered. Progress (the pair index) and the total time in minutes
        are printed.
    """
    # os.listdir order is arbitrary, so sort to get a reproducible pairing.
    ref_names = sorted(f for f in os.listdir(reference_image_path) if f.endswith(".png"))
    mov_names = sorted(f for f in os.listdir(moving_image_path) if f.endswith(".png"))
    if len(ref_names) != len(mov_names):
        print("Number of images in reference and moving file paths are not equal, please fix and try again!")
        return

    start = time()
    for idx, (ref_name, mov_name) in enumerate(zip(ref_names, mov_names)):
        # Converting to RGB gives the same grayscale as the old RGBA2GRAY path for RGBA
        # files, and also accepts RGB / grayscale / palette PNGs.
        ref_rgb = _load_rgb(os.path.join(reference_image_path, ref_name))
        mov_rgb = _load_rgb(os.path.join(moving_image_path, mov_name))
        ref_gray = cv2.cvtColor(ref_rgb, cv2.COLOR_RGB2GRAY)
        mov_gray = cv2.cvtColor(mov_rgb, cv2.COLOR_RGB2GRAY)

        # v: row displacement, u: column displacement for every reference pixel.
        v, u = optical_flow_tvl1(ref_gray, mov_gray)
        nr, nc = ref_gray.shape
        row_coords, col_coords = np.meshgrid(np.arange(nr), np.arange(nc), indexing='ij')
        # Output pixel (r, c) is sampled from the moving image at (r + v, c + u), so the
        # warped image lies on the reference image grid.
        coords = np.array([row_coords + v, col_coords + u])
        channels = [_to_uint8(warp(mov_rgb[:, :, ch], coords, mode='edge')) for ch in range(3)]
        reg_img = Image.fromarray(np.stack(channels, axis=2))
        print(idx)
        reg_img.save(os.path.join(save_path, mov_name.replace('.png', '') + '.png'))

    end = time()
    print("time it took to register: "+  str((end-start)/60) + " minutes")


In [22]:
# Inspect one saved optical-flow component array for image z0001_1C1 (the `u` folder name
# suggests the column displacement - TODO confirm how these arrays were produced).
a = np.load(r'\\fatherserverdw\Kevin\imageregistration2\warp_arrays\u\z0001_1C1.npy')
a

array([[  39.70017 ,   39.700134,   39.700073, ..., -106.271645,
        -106.27186 , -106.27196 ],
       [  39.700184,   39.70015 ,   39.700085, ..., -106.2719  ,
        -106.2721  , -106.27221 ],
       [  39.700214,   39.700184,   39.700123, ..., -106.272385,
        -106.2726  , -106.272705],
       ...,
       [ 116.44678 ,  116.446785,  116.44678 , ..., -146.92422 ,
        -147.25272 , -147.42929 ],
       [ 116.44665 ,  116.44664 ,  116.44663 , ..., -147.00293 ,
        -147.26627 , -147.43987 ],
       [ 116.44658 ,  116.44658 ,  116.446556, ..., -147.03467 ,
        -147.27574 , -147.44449 ]], dtype=float32)

In [14]:
# Show the "ij" grid layout: the first array holds row indices (constant along each row),
# the second holds column indices - the same layout registrate_two_images uses for
# row_coords / col_coords.
ra, ra1 = np.indices((10, 10))
ra

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
       [5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
       [6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
       [7, 7, 7, 7, 7, 7, 7, 7, 7, 7],
       [8, 8, 8, 8, 8, 8, 8, 8, 8, 8],
       [9, 9, 9, 9, 9, 9, 9, 9, 9, 9]])

In [None]:
def target_registration_error(reference_image,moving_image, display_errors = False, min_err = None, max_err = None, figure_size=(8,6)):
    """
    DRAFT (cell marked In [None], never run): optical-flow based Target Registration Error
    between a reference (fixed) image file and a moving image file.

    Edited from: https://github.com/SimpleITK/SPIE2018_COURSE/blob/master/utilities.py

    Intended steps, as far as the draft shows: load both images, convert them from RGBA to
    grayscale, estimate TV-L1 optical flow (v, u), map a set of fixed-image points through
    the flow, compute per-point errors, and optionally scatter-plot them.

    NOTE(review): this draft does not parse or run yet:
      * `points_ra = 2d array ...`, `points_list =`, `transformed_points_list =` and
        `errors = #some list` are placeholders, not Python statements.
      * `points` (passed to `warp`) and `point_list` (passed to `ax.scatter`) are never defined.
      * `ax` is created only when `display_errors` is true, yet the plotting code after that
        `if` runs unconditionally.
      * `if not min_err` / `if not max_err` also replace an explicit bound of 0.
      * the trailing `return errors` is at column 0, i.e. outside this function.
      * `start` is assigned but never used.
    TODO: a TRE needs ground-truth corresponding points in the moving image (like the
    `reference_point_list` argument of `target_registration_errors`); a flow field on its
    own only gives displacements.
    """
    start = time()
    ref_img = Image.open(reference_image)
    mov_img = Image.open(moving_image)
    ref_img = np.array(ref_img)
    mov_img = np.array(mov_img)
    ref_img_g = cv2.cvtColor(ref_img,cv2.COLOR_RGBA2GRAY)
    mov_img_g = cv2.cvtColor(mov_img,cv2.COLOR_RGBA2GRAY)
    v, u = optical_flow_tvl1(ref_img_g, mov_img_g)
    # TODO(review): the next lines are placeholders and must be filled in before running.
    points_ra = 2d array of some points from fixed image (np.meshgrid)
    points_list =
    # NOTE(review): registrate_two_images passes absolute coordinates
    # (`row_coords + v`, `col_coords + u`) to warp; here only the raw displacements are passed.
    transformed_points_ra = warp(points,np.array([v, u]),mode='edge')
    transformed_points_list =
    errors = #some list
    if display_errors:
        fig = plt.figure(figsize=figure_size)
        ax = fig.add_subplot(111, projection='3d')
    # NOTE(review): everything from here to plt.show() was probably meant to sit inside the
    # `if display_errors:` block (it does in target_registration_errors).
    if not min_err:
        min_err = np.min(errors)
    if not max_err:
        max_err = np.max(errors)

    collection = ax.scatter(list(np.array(point_list).T)[0],
                            list(np.array(point_list).T)[1],
                            list(np.array(point_list).T)[2],
                            marker = 'o',
                            c = errors,
                            vmin = min_err,
                            vmax = max_err,
                            cmap = matplotlib.cm.hot,
                            label = 'original points')
    plt.colorbar(collection, shrink=0.8)
    plt.title('registration errors in mm', x=0.7, y=1.05)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()

# NOTE(review): this return is at column 0 (outside the function) - indent it into the body.
return errors

In [20]:
# Scratch check of skimage `warp` with a small array as `inverse_map`.
# NOTE(review): `A + B` is a (2, 2) array. skimage documents an ndarray inverse_map as either a
# (3, 3) homogeneous matrix or an array of sampling coordinates (like the
# `[row_coords + v, col_coords + u]` map in registrate_two_images). A (2, 2) array appears to be
# read as the coordinates of two points (row 3, col 3), both outside the 2x2 image, which would
# explain the recorded output array([0., 0.]) - i.e. this presumably does not test what was intended.
# `C` is never used.
A = np.array([[2,2],[2,2]])
B = np.array([[1,1],[1,1]])
C = np.array([[3,3],[3,3]])
D = warp(A,A+B)
D

array([0., 0.])

In [3]:
def _plot_registration_errors(point_list, errors, min_err, max_err, figure_size):
    """Scatter-plot the fixed points coloured by their error (3-D axes for 3-D points, else 2-D)."""
    import matplotlib.pyplot as plt

    coords = np.asarray(point_list, dtype=float)
    # Fall back to the data range only when no bound was given: 0.0 is a valid bound.
    vmin = np.min(errors) if min_err is None else min_err
    vmax = np.max(errors) if max_err is None else max_err
    style = dict(marker='o', c=errors, vmin=vmin, vmax=vmax, cmap='hot', label='original points')

    fig = plt.figure(figsize=figure_size)
    if coords.shape[1] >= 3:
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 - registers '3d' on older matplotlib
        ax = fig.add_subplot(111, projection='3d')
        collection = ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], **style)
        ax.set_zlabel('Z')
    else:
        ax = fig.add_subplot(111)
        collection = ax.scatter(coords[:, 0], coords[:, 1], **style)
    plt.colorbar(collection, shrink=0.8)
    plt.title('registration errors in mm', x=0.7, y=1.05)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    plt.show()


def target_registration_errors(tx, point_list, reference_point_list,
                               display_errors = False, min_err= None, max_err=None, figure_size=(8,6)):
    """
    Distances between points transformed by the given transformation and their
    location in another coordinate system. When the points are only used to
    evaluate registration accuracy (not used in the registration) this is the
    Target Registration Error (TRE).

    Args:
        tx (SimpleITK.Transform): The transform to evaluate. Any object with a
            ``TransformPoint(point) -> point`` method works.
        point_list (list(tuple-like)): Points in the fixed image coordinate system.
        reference_point_list (list(tuple-like)): Corresponding points in the moving
            image coordinate system. Points are paired by position; extra points in
            the longer list are ignored (zip semantics).
        display_errors (boolean): Show a scatter plot of the points from point_list,
            coloured by their error (3-D axes for 3-D points, 2-D axes for 2-D points).
        min_err, max_err (float): The colour range is linearly stretched between min_err
            and max_err. If a value is None, the range of the computed errors is used
            (an explicit 0 is respected).
        figure_size (tuple): Figure size in inches.

    Returns:
        (errors) [float]: list of TRE values, one per point pair.
    """
    # Map each fixed-image point into the moving image's coordinate system.
    transformed_point_list = [tx.TransformPoint(p) for p in point_list]

    # TRE = distance between where the transform places a fixed point and where that same
    # point actually is in the moving image.
    errors = [
        float(np.linalg.norm(np.asarray(p_fixed, dtype=float) - np.asarray(p_moving, dtype=float)))
        for p_fixed, p_moving in zip(transformed_point_list, reference_point_list)
    ]

    if display_errors and errors:  # nothing to plot (and np.min would fail) when empty
        _plot_registration_errors(point_list, errors, min_err, max_err, figure_size)
    return errors