### Experiment 5: Experiment with 3 depth stacks vs 2D

The goal of this experiment is to investigate whether the model is able to learn the transformation when given information from the preceding and following depths.

To decide: how the output will be determined — whether it will be a 3D stack or a single image (of the middle depth).

**Methods**: generate 3D patches of 3 consecutive depths.




In [None]:
# !pip install tifffile
# !pip install sklearn
# !pip install scikit-image
!pip install import_ipynb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# importing dependencies
import matplotlib.pyplot as plt
import sklearn
import sklearn.model_selection
import skimage
import math
import numpy as np
from sklearn import preprocessing
import tensorflow as tf
import math
import keras.backend as K
from datetime import datetime
import fractions
import itertools
import tqdm
from keras.utils.conv_utils import normalize_tuple

In [None]:
# Colab only: mount Google Drive at /content/drive so data and helper notebooks
# stored there are reachable from this kernel.
from google.colab import drive
import sys
import os
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Cluster paths used in earlier runs:
# DATA_PATH = '/cluster/tufts/georgakoudi_lab01/tdinh02/npz/Cervix_all_data_original.npz'
# LABEL_PATH = '/cluster/tufts/georgakoudi_lab01/tdinh02/npz/Cervix_all_data_labels_original.npz'
# MAIN_PATH = '/cluster/tufts/georgakoudi_lab01/tdinh02/objective_transform/'

# Input and label archives for the 3-depth experiment. Both are relative to
# MAIN_PATH: they are loaded as MAIN_PATH + DATA_PATH / MAIN_PATH + LABEL_PATH.
DATA_PATH = '/npz/Cervix_all_data_3depths_3D.npz'
LABEL_PATH = '/npz/Cervix_all_data_labels_3depths_3D.npz'
MAIN_PATH= r"/objective_transfer/deep_learning" # Artem's Drive
now = datetime.now() # current date and time
CURR_DATE = now.strftime("%m-%d-%Y")  # MM-DD-YYYY string of today's date

def config(patch_size, depth):
  """Hyper-parameter dict for a run on square ``patch_size`` patches.

  ``depth`` is stored as ``n_depth`` (the U-Net depth passed to ``build_model``).
  This definition is overridden by the later ``config`` in the 3-depth
  experiment cell.
  """
  settings = dict(
      img_size=512,
      learning_rate=1e-5,
      batch_size=16,
      alpha=0.84,
      patch_size=patch_size,
      input_shape=[patch_size, patch_size, 3],
      kern_size=3,
      n_depth=depth,
      first_depth=32,
      dropout=0,
      epoch=1,
      lr_decay_factor=0.97,
      lr_decay_patience=10,
  )
  return settings

MODEL_ROOT_NAME = "CARE_patch_depth_tune_0726"

In [None]:
# Switch the working directory to the CARE code folder and import the helper
# notebook CARE_util.ipynb (import_ipynb lets .ipynb files be imported as modules).
# NOTE(review): the star import presumably provides `common_unet` used by build_model.
os.chdir(MAIN_PATH+'/CAREstd/')
import import_ipynb
from CARE_util import *

importing Jupyter notebook from CARE_util.ipynb


In [None]:
# Earlier variant with TensorFlow's default (11x11) Gaussian window:
# def ssim(y_true, y_pred):
#     return (tf.image.ssim(y_true, y_pred,1,k2=0.05))
def ssim(y_true, y_pred):
    """SSIM between two image batches whose maximum value is 1.

    A small 3x3 Gaussian window (sigma 0.5) is used to avoid blurring, with k2 = 0.05.
    """
    return tf.image.ssim(
        y_true, y_pred, 1,
        filter_size=3,
        filter_sigma=0.5,
        k2=0.05,
    )

# per-scale exponents (5 scales) passed to ssim_multiscale as power_factors
w = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
def mssim(y_true, y_pred):
   """Multi-scale SSIM (max value 1) with an 11x11 window, sigma 1.5 and k2 = 0.05."""
   return tf.image.ssim_multiscale(
       y_true, y_pred, 1,
       power_factors=w,
       filter_size=11,
       filter_sigma=1.5,
       k2=0.05,
   )

def psnr(y_true, y_pred):
    '''
    Computes the peak signal-to-noise ratio between two images, assuming a
    maximum signal value of 1.
    '''
    peak = 1.0
    return tf.image.psnr(y_true, y_pred, max_val=peak)

def SSIM_loss(y_true, y_pred):
    """Map SSIM from [-1, 1] onto a loss in [0, 1] (0 means identical images)."""
    similarity = ssim(y_true, y_pred)
    # author's note on this line: "change to 2"
    return 1 - (similarity + 1) * 0.5

from scipy import signal
def _ricker(points, a):
    """Ricker ("Mexican hat") wavelet with width ``a`` sampled at ``points`` positions.

    Same formula as the former ``scipy.signal.ricker`` (removed in SciPy 1.15).
    """
    amplitude = 2 / (np.sqrt(3 * a) * (np.pi ** 0.25))
    vec = np.arange(0, points) - (points - 1.0) / 2
    xsq = vec ** 2
    wsq = a ** 2
    return amplitude * (1 - xsq / wsq) * np.exp(-xsq / (2 * wsq))


def _cwt_ricker(data, widths):
    """Continuous wavelet transform of 1D ``data`` with Ricker wavelets.

    Reproduces ``scipy.signal.cwt(data, scipy.signal.ricker, widths)`` (removed
    in SciPy 1.15). Returns an array of shape ``(len(widths), len(data))``.
    """
    out = np.empty((len(widths), len(data)), dtype=np.float64)
    for row, width in enumerate(widths):
        # same wavelet support as SciPy: 10 * width samples, capped at the signal length
        n_points = min(10 * width, len(data))
        out[row] = signal.convolve(data, _ricker(n_points, width)[::-1], mode='same')
    return out


def cw_ssim(y_true, y_pred, width=30):
    """Complex-wavelet SSIM (CW-SSIM) between two images.

    Both images are flattened to 1D signals and decomposed with a Ricker-wavelet
    continuous wavelet transform at widths ``1..width``. The per-position
    CW-SSIM values are then averaged. (An earlier docstring described PIL /
    ``SSIMImage`` inputs, which this function never accepted.)

    Args:
      y_true: reference image; a TF eager tensor (anything with ``.numpy()``)
        or an array-like.
      y_pred: image to compare; same kinds and number of elements as ``y_true``.
      width: largest wavelet width used in the transform (default 30).

    Returns:
      CW-SSIM as a float in [0, 1]; 1 means identical signals.

    Raises:
      ValueError: if the images are empty or have different numbers of elements.
    """
    k = 0.01  # CW-SSIM stabilising constant (default 0.01)
    widths = np.arange(1, width + 1)

    def _as_signal(x):
        # tensors expose .numpy(); plain arrays/lists are accepted as-is
        x = x.numpy() if hasattr(x, 'numpy') else x
        return np.asarray(x, dtype=np.float64).ravel()

    sig1 = _as_signal(y_pred)
    sig2 = _as_signal(y_true)
    if sig1.size == 0 or sig1.size != sig2.size:
        raise ValueError(
            f'cw_ssim needs two non-empty images of equal size, got {sig1.size} and {sig2.size}')

    cwtmatr1 = _cwt_ricker(sig1, widths)
    cwtmatr2 = _cwt_ricker(sig2, widths)

    # First term: magnitude similarity of the coefficients
    c1c2 = np.abs(cwtmatr1) * np.abs(cwtmatr2)
    num_ssim_1 = 2 * np.sum(c1c2, axis=0) + k
    den_ssim_1 = np.sum(np.abs(cwtmatr1) ** 2, axis=0) + np.sum(np.abs(cwtmatr2) ** 2, axis=0) + k

    # Second term: phase consistency of the coefficients
    c1c2_conj = cwtmatr1 * np.conjugate(cwtmatr2)
    num_ssim_2 = 2 * np.abs(np.sum(c1c2_conj, axis=0)) + k
    den_ssim_2 = 2 * np.sum(np.abs(c1c2_conj), axis=0) + k

    ssim_map = (num_ssim_1 / den_ssim_1) * (num_ssim_2 / den_ssim_2)
    # Average the per-position results
    return np.average(ssim_map)

def SSIML1_loss(y_true, y_pred, alpha=0.84):
  """Weighted sum of an SSIM loss and the per-sample mean absolute error.

  Returns ``alpha * (1 - (SSIM + 1) / 2) + (1 - alpha) * MAE``, where MAE is
  computed on the batch-flattened images.
  """
  flat_true = tf.keras.backend.batch_flatten(y_true)
  flat_pred = tf.keras.backend.batch_flatten(y_pred)
  mae_term = tf.keras.losses.mae(flat_true, flat_pred)
  ssim_term = 1-((ssim(y_true, y_pred)+1)*0.5)
  return alpha*ssim_term + (1-alpha)*mae_term

def mov_var(image):
  """Squared deviation of each pixel from its 3x3 neighbourhood mean, divided by 8.

  Args:
    image: float32 4D tensor ``(batch, height, width, channels)``. The 3x3 mean
      filter has one input channel, so ``channels`` must be 1.

  Returns:
    Tensor ``(batch, height - 2, width - 2, 1)``. The 'VALID' convolution drops
    a 1-pixel border, and the image is cropped to match.
  """
  mean_filter = tf.ones((3, 3), tf.float32) / 9
  img_mean = tf.nn.conv2d(image,
                          mean_filter[:, :, tf.newaxis, tf.newaxis],
                          [1, 1, 1, 1], 'VALID')
  img_clip = image[:, 1:-1, 1:-1, :]
  # Squared difference between each pixel and its block mean; /8 is presumably
  # the (n - 1) normaliser of a 3x3 sample variance.
  return tf.math.squared_difference(img_clip, img_mean) / 8

def genSSIML1_loss(alpha=0.84):
  """Factory for an SSIM + L1 loss with a fixed weight ``alpha``.

  The returned ``SSIM_L1_loss(y_true, y_pred)`` computes
  ``alpha * (1 - (SSIM + 1) / 2) + (1 - alpha) * MAE``.
  """
  def SSIM_L1_loss(y_true, y_pred):
    flat = [tf.keras.backend.batch_flatten(t) for t in (y_true, y_pred)]
    mae_partial = tf.keras.losses.mae(flat[0], flat[1])
    ssim_partial = 1-((ssim(y_true, y_pred)+1)*0.5)
    return alpha*ssim_partial + (1-alpha)*mae_partial
  return SSIM_L1_loss

def genSSIMVar_loss(alpha=0.84):
  """Factory for an SSIM loss combined with an L1 penalty on ``mov_var`` maps.

  The returned ``SSIMVar_loss(y_true, y_pred)`` computes
  ``alpha * (1 - (SSIM + 1) / 2) + (1 - alpha) * MAE(mov_var(y_true), mov_var(y_pred))``.
  """
  def SSIMVar_loss(y_true, y_pred):
    var_true = tf.keras.backend.batch_flatten(mov_var(y_true))
    var_pred = tf.keras.backend.batch_flatten(mov_var(y_pred))
    var_mae = tf.keras.losses.mae(var_true, var_pred)
    ssim_term = 1-((ssim(y_true, y_pred)+1)*0.5)
    return alpha * ssim_term + (1-alpha) * var_mae
  return SSIMVar_loss

def genSSIMVarL1_loss(alpha=0.84):
  """Factory for SSIM + L1-on-``mov_var`` + plain L1 loss.

  The non-SSIM weight ``1 - alpha`` is split 1/4 to the ``mov_var`` MAE and
  3/4 to the pixel MAE.
  """
  def SSIMVarL1_loss(y_true, y_pred):
    flatten = tf.keras.backend.batch_flatten
    var_mae = tf.keras.losses.mae(flatten(mov_var(y_true)), flatten(mov_var(y_pred)))
    pixel_mae = tf.keras.losses.mae(flatten(y_true), flatten(y_pred))
    ssim_term = 1-((ssim(y_true, y_pred)+1)*0.5)
    return alpha * ssim_term + ((1-alpha)/4) * var_mae + (3*(1-alpha)/4) * pixel_mae
  return SSIMVarL1_loss

In [None]:
# HELPER FUNCTIONS

def create_patches(img, patch_shape, slide, depth=3):
    """Tile one image into overlapping square patches.

    Args:
      img: array ``(height, width, depth)``; it is assumed to have exactly
        ``depth`` entries on its last axis.
      patch_shape: side length of each patch.
      slide: overlap between neighbouring patches; windows start
        ``patch_shape - slide`` pixels apart.
      depth: size of the last axis of every patch.

    Returns:
      Array ``(n_patches, patch_shape, patch_shape, depth)`` in row-major window order.
    """
    windows = skimage.util.view_as_windows(
        img, (patch_shape, patch_shape, depth), step=patch_shape - slide)
    n_rows, n_cols = windows.shape[0], windows.shape[1]
    # reshaping the window view is cheaper than collecting patches in a loop
    return windows.reshape(n_rows * n_cols, patch_shape, patch_shape, depth)

def patchify(input, patch_shape, slide, depth=3):
    """Cut every image of a stack into overlapping square patches.

    Args:
      input: array ``(n_images, height, width, depth)``.
      patch_shape: side length of each square patch.
      slide: overlap in pixels between neighbouring patches; consecutive
        patches start ``patch_shape - slide`` pixels apart.
      depth: size of the last axis of each image and patch (forwarded to
        ``create_patches``).

    Returns:
      float64 array ``(n_patches, patch_shape, patch_shape, depth)``: all
      patches of image 0, then those of image 1, and so on.

    Raises:
      ValueError: if ``slide >= patch_shape`` (the step would not be positive).
    """
    step = patch_shape - slide
    if step <= 0:
        raise ValueError(
            f'slide ({slide}) must be smaller than patch_shape ({patch_shape})')
    n_images, height, width = np.shape(input)[:3]
    # Exact window count per axis (as produced by view_as_windows). The former
    # float estimate over-allocated when sizes did not divide evenly, which left
    # trailing all-zero patches.
    per_image = ((height - patch_shape) // step + 1) * ((width - patch_shape) // step + 1)
    data = np.zeros((int(n_images * per_image), patch_shape, patch_shape, depth))
    count = 0
    for i in range(n_images):
        patches = create_patches(input[i], patch_shape, slide, depth)
        data[count:count + len(patches)] = patches
        count += len(patches)
    print("      [PATCHIFYING COMPLETED] output shape, slide: ",np.shape(data),slide,"; number of images: ", np.shape(input)[0], ", number of patches: ", np.shape(data)[0])
    return data


def load_data(path, expand=False, patch=None):
    """Load the training and validation arrays (keys 't' and 'v') from an .npz file.

    Args:
      path: path to the ``.npz`` archive.
      expand: if truthy, append a trailing singleton axis to both arrays.
      patch: if given, cut every image into ``patch x patch`` tiles that overlap
        by ``int(patch / 2)`` pixels (see ``patchify``).

    Returns:
      ``[training_data, val_data]``.
    """
    # The context manager closes the underlying zip file once both arrays have
    # been read into memory.
    with np.load(path) as archive:
        training_data = archive['t']
        val_data = archive['v']

    # patchify if a patch size is passed
    if patch is not None:
        slide = int(patch/2)
        training_data = patchify(training_data, patch, slide)
        val_data = patchify(val_data, patch, slide)

    if expand:
        training_data = expandLastDim(training_data)
        val_data = expandLastDim(val_data)
    return [training_data, val_data]

def expandLastDim(data):
    """Return ``data`` as an array with a new trailing axis of length 1."""
    return np.asanyarray(data)[..., np.newaxis]

# Samples random (source, target) patch pairs from image stacks, adds a channel
# axis, optionally augments them, and serves them in batches via a keras Sequence.
class DataGenerator:
    '''
    Generates batches of image pairs with real-time data augmentation.
    Parameters
    ----------
    shape: tuple of int
        Shape of batch images (excluding the channel dimension).
    batch_size: int
        Batch size.
    transform_function: str or callable or None
        Function used for data augmentation. Typically you will set
        ``transform_function='rotate_and_flip'`` to apply combination of
        randomly selected image rotation and flipping.  Alternatively, you can
        specify an arbitrary transformation function which takes two input
        images (source and target) and returns transformed images. If
        ``transform_function=None``, no augmentation will be performed.
    intensity_threshold: float
        If ``intensity_threshold > 0``, pixels whose intensities are greater
        than this threshold will be considered as foreground.
    area_ratio_threshold: float between 0 and 1
        If ``intensity_threshold > 0``, the generator calculates the ratio of
        foreground pixels in a target patch, and rejects the patch if the ratio
        is smaller than this threshold.
    scale_factor: int != 0
        Scale factor for the target patch size. Positive and negative values
        mean up- and down-scaling respectively.
    '''
    def __init__(self,
                 shape,
                 batch_size,
                 transform_function='rotate_and_flip',
                 intensity_threshold=0.0,
                 area_ratio_threshold=0.0,
                 scale_factor=1):
        def rotate_and_flip(x, y, dim):
            # x and y get the same random rotation/flip so the pair stays aligned
            if dim == 2:
                k = np.random.randint(0, 4)
                x, y = [np.rot90(v, k=k) for v in (x, y)]
                if np.random.random() < 0.5:
                    x, y = [np.fliplr(v) for v in (x, y)]
                return x, y
            elif dim == 3:
                k = np.random.randint(0, 4)
                x, y = [np.rot90(v, k=k, axes=(1, 2)) for v in (x, y)]
                if np.random.random() < 0.5:
                    x, y = [np.flip(v, axis=1) for v in (x, y)]
                if np.random.random() < 0.5:
                    x, y = [np.flip(v, axis=0) for v in (x, y)]
                return x, y
            else:
                raise ValueError('Unsupported dimension')

        self._shape = tuple(shape)
        self._batch_size = batch_size

        dim = len(self._shape)

        if transform_function == 'rotate_and_flip':
            if shape[-2] != shape[-1]:
                raise ValueError(
                    'Patch shape must be square when using `rotate_and_flip`; '
                    f'Received shape: {shape}')
            self._transform_function = lambda x, y: rotate_and_flip(x, y, dim)
        elif callable(transform_function):
            self._transform_function = transform_function
        elif transform_function is None:
            self._transform_function = lambda x, y: (x, y)
        else:
            raise ValueError('Invalid transform function')

        self._intensity_threshold = intensity_threshold

        if not 0 <= area_ratio_threshold <= 1:
            raise ValueError('"area_ratio_threshold" must be between 0 and 1')
        # minimum count of foreground target pixels for a patch to be accepted
        self._area_threshold = area_ratio_threshold * np.prod(shape)

        self._scale_factor = normalize_tuple(scale_factor, dim, 'scale_factor')
        if any(not isinstance(f, int) or f == 0 for f in self._scale_factor):
            raise ValueError('"scale_factor" must be nonzero integer')

    class _Sequence(tf.keras.utils.Sequence):
        def _scale(self, shape):
            # positive factor multiplies (upsampling), negative factor divides (downsampling)
            return tuple(
                s * f if f > 0 else s // -f
                for s, f in zip(shape, self._scale_factor))

        def __init__(self,
                     x,
                     y,
                     batch_size,
                     shape,
                     transform_function,
                     intensity_threshold,
                     area_threshold,
                     scale_factor):
            self._batch_size = batch_size
            self._transform_function = transform_function
            self._intensity_threshold = intensity_threshold
            self._area_threshold = area_threshold
            self._scale_factor = scale_factor

            for s, f in zip(shape, self._scale_factor):
                if f < 0 and s % -f != 0:
                    raise ValueError(
                        'When downsampling, all elements in `shape` must be '
                        'divisible by the scale factor; '
                        f'Received shape: {shape}, '
                        f'scale factor: {self._scale_factor}')

            self._x, self._y = [
                list(m) if isinstance(m, (list, tuple)) else [m]
                for m in [x, y]]
            # NOTE(review): np.moveaxis turns the list into one array and moves the
            # list axis to the end, e.g. [stack (N, H, W)] -> array (N, H, W, 1).
            # self._x / self._y are therefore arrays indexed per image, and the
            # channel insertion below (assigning a higher-rank array into
            # self._x[i]) would raise if it were ever reached — verify the layout.
            self._x = np.moveaxis(self._x,0,-1)
            self._y = np.moveaxis(self._y,0,-1)
            if len(self._x) != len(self._y):
                raise ValueError(
                    'Different number of images are given: '
                    f'{len(self._x)} vs. {len(self._y)}')

            if len({m.dtype for m in self._x}) != 1:
                raise ValueError('All source images must be the same type')
            if len({m.dtype for m in self._y}) != 1:
                raise ValueError('All target images must be the same type')
            print(len(self._x))  # debug: number of source images
            for i in range(len(self._x)):
                if len(self._x[i].shape) == len(shape):
                    self._x[i] = self._x[i][..., np.newaxis]

                if len(self._y[i].shape) == len(shape):
                    self._y[i] = self._y[i][..., np.newaxis]

                if len(self._x[i].shape) != len(shape) + 1:
                    raise ValueError(f'Source image must be {len(shape)}D')

                if len(self._y[i].shape) != len(shape) + 1:
                    raise ValueError(f'Target image must be {len(shape)}D')
                # Compare axis by axis: tuple `<` is lexicographic and would accept
                # e.g. an image of (300, 100) for a (256, 256) patch.
                if any(a < b for a, b in zip(self._x[i].shape[:-1], shape)):
                    raise ValueError(
                        'Source image must be larger than the patch size')

                expected_y_image_size = self._scale(self._x[i].shape[:-1])
                if self._y[i].shape[:-1] != expected_y_image_size:
                    raise ValueError('Invalid target image size: '
                                     f'expected {expected_y_image_size}, '
                                     f'but received {self._y[i].shape[:-1]}')

            if len({m.shape[-1] for m in self._x}) != 1:
                raise ValueError(
                    'All source images must have the same number of channels')
            if len({m.shape[-1] for m in self._y}) != 1:
                raise ValueError(
                    'All target images must have the same number of channels')
            # preallocated batch buffers, overwritten by every __getitem__ call
            self._batch_x = np.zeros(
                (batch_size, *shape, self._x[0].shape[-1]),
                dtype=self._x[0].dtype)
            self._batch_y = np.zeros(
                (batch_size, *self._scale(shape),self._y[0].shape[-1]),
                dtype=self._y[0].dtype)

        def __len__(self):
            # batches per epoch; batches are sampled randomly, so this is nominal
            return len(self._x) // self._batch_size

        def __next__(self):
            return self.__getitem__(0)

        def __getitem__(self, _):
            # The index is ignored: every call draws a fresh random batch.
            for i in range(self._batch_size):
                # up to 139 attempts to find a patch that passes the foreground check
                for _ in range(139):
                    j = np.random.randint(0, len(self._x))

                    # random top-left corner such that the patch fits in image j
                    tl = [np.random.randint(0, a - b + 1)
                          for a, b in zip(
                              self._x[j].shape, self._batch_x.shape[1:])]

                    x = np.copy(self._x[j][tuple(
                        [slice(a, a + b) for a, b in zip(
                            tl, self._batch_x.shape[1:])])])
                    y = np.copy(self._y[j][tuple(
                        [slice(a, a + b) for a, b in zip(
                            self._scale(tl), self._batch_y.shape[1:])])])

                    if (self._intensity_threshold <= 0.0 or
                            np.count_nonzero(y > self._intensity_threshold)
                            >= self._area_threshold):
                        break
                else:
                    import warnings
                    warnings.warn(
                        'Failed to sample a valid patch',
                        RuntimeWarning,
                        stacklevel=3)

                self._batch_x[i], self._batch_y[i] = \
                    self._transform_function(x, y)
            return self._batch_x, self._batch_y

    def flow(self, x, y):
        '''
        Returns a `keras.utils.Sequence` object which generates batches
        infinitely. It can be used as an input generator for
        `keras.models.Model.fit_generator()`.
        Parameters
        ----------
        x: array_like or list of array_like
            Source image(s).
        y: array_like or list of array_like
            Target image(s).
        Returns
        -------
        keras.utils.Sequence
            `keras.utils.Sequence` object which generates tuples of source and
            target image patches.
        '''
        return self._Sequence(x,
                              y,
                              self._batch_size,
                              self._shape,
                              self._transform_function,
                              self._intensity_threshold,
                              self._area_threshold,
                              self._scale_factor)

# CALLBACKS
# TensorBoard callback writing to ./logs and updating after every batch
tb_callback = tf.keras.callbacks.TensorBoard(
    './logs',
    update_freq=1,
    write_images=True,
    write_steps_per_second=True,
)


def build_model(n_dim=3, n_depth=2, kern_size=5, dropout= 0, n_first=32, n_channel_out=3, last_activation='relu', batch_norm=False,shape=(None,None,3), residual=True):
    """Build a U-Net with ``common_unet`` for inputs of ``shape``.

    Prints the model summary, saves the freshly built weights to ``model.h5``
    in the current directory, and returns the model.
    """
    unet_factory = common_unet(
        n_dim,
        n_depth,
        n_first=n_first,
        kern_size=kern_size,
        n_channel_out=n_channel_out,
        dropout=dropout,
        last_activation=last_activation,
        batch_norm=batch_norm,
        residual=residual,
    )
    model = unet_factory(shape)
    model.summary()
    model.save_weights('model.h5')
    return model

# Compile model for training
def compile_model(model, lr, loss=SSIML1_loss, metric=None):
    """Compile ``model`` with the Adam optimizer and return it.

    Args:
      model: Keras model; it is compiled in place and returned.
      lr: Adam learning rate.
      loss: loss function (default ``SSIML1_loss``).
      metric: value passed to ``metrics=``. Defaults to a fresh
        ``[{'psnr': psnr, 'ssim': ssim}]``; a ``None`` default avoids sharing
        one mutable list across calls.

    Returns:
      The same ``model`` object.
    """
    if metric is None:
        metric = [{'psnr': psnr, 'ssim': ssim}]
    # Alternatives tried: loss=tf.keras.losses.MeanAbsoluteError(),
    # metrics=[{'mae': tf.keras.losses.mae}]
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  loss=loss,
                  metrics=metric)
    return model

# Create a ModelCheckpoint callback that saves only the model's weights.
def save_weight_callback(checkpoint_path, batch_size, epoch_freq=10):
  """Return a verbose ``ModelCheckpoint`` that saves weights to ``checkpoint_path``.

  NOTE(review): Keras counts an integer ``save_freq`` in *batches*, so this saves
  every ``epoch_freq * batch_size`` batches. That equals every ``epoch_freq``
  epochs only if ``batch_size`` is actually the number of steps per epoch —
  confirm what callers pass.
  """
  return tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=epoch_freq*batch_size)

def display_training_stats(history, savepath, stat, axis_name):
    """Plot training and validation curves of ``stat`` and save them as a PNG.

    Reads ``history.history[stat]`` and ``history.history['val_' + stat]``, and
    writes ``<savepath>_<stat>.png``.
    """
    plt.figure(figsize=(10,10))
    for key in (stat, 'val_' + stat):
        plt.plot(history.history[key], label=key)
    plt.xlabel(axis_name)
    plt.ylabel(stat)
    plt.legend(loc='upper right')
    plt.savefig(savepath + '_' + stat + '.png')

In [None]:
def _get_gaussian_kernel(dim, size, sigma):
    """Build a normalised, separable Gaussian convolution kernel.

    Args:
      dim: number of spatial dimensions of the kernel (2 or 3 in ``ssim_rcan``).
      size: number of taps along each axis.
      sigma: standard deviation of the Gaussian, in pixels.

    Returns:
      float32 tensor of shape ``(size,) * dim + (1, 1)`` whose entries sum to 1,
      i.e. a single-input / single-output-channel convolution filter.
    """
    k = size // 2
    coords = tf.range(-k, size - k, dtype=tf.float32)
    # Unnormalised 1D Gaussian. The 1/(sigma*sqrt(2*pi)) factor cancels in the
    # normalisation below, so tensorflow_probability is not needed for it.
    p = tf.exp(-0.5 * tf.square(coords / tf.cast(sigma, tf.float32)))

    indices = [chr(i) for i in range(105, 105 + dim)]  # 'i', 'j', 'k', ...
    eq = ','.join(indices) + '->' + ''.join(indices)   # outer product, e.g. 'i,j->ij'
    kernel = tf.einsum(eq, *([p] * dim))
    kernel /= tf.reduce_sum(kernel)
    return kernel[..., tf.newaxis, tf.newaxis]

# RCAN ssim
import keras.backend as K
def ssim_rcan(y_true, y_pred):
    '''
    Structural similarity index between two image batches, following the RCAN
    implementation: 11x11 (or 11^3) Gaussian window with sigma 1.5, applied per
    channel. The maximum signal value is assumed to be 1.
    References
    ----------
    Image Quality Assessment: From Error Visibility to Structural Similarity
    https://doi.org/10.1109/TIP.2003.819861
    '''
    spatial_dims = K.ndim(y_pred) - 2
    if spatial_dims not in (2, 3):
        raise NotImplementedError(f'{spatial_dims}D SSIM is not suported')

    # stabilising constants for a dynamic range of 1
    c1, c2 = 0.01 ** 2, 0.03 ** 2

    channels = K.int_shape(y_pred)[-1]
    window = _get_gaussian_kernel(spatial_dims, 11, 1.5)
    convolve = K.conv3d if spatial_dims == 3 else K.conv2d

    def gaussian_mean(t):
        # convolve every channel separately with the single-channel window
        per_channel = tf.split(t, channels, axis=-1)
        return tf.concat([convolve(ch, window) for ch in per_channel], axis=-1)

    mu_true = gaussian_mean(y_true)
    mu_pred = gaussian_mean(y_pred)
    mu_prod = mu_true * mu_pred
    mu_sq_sum = K.square(mu_true) + K.square(mu_pred)
    cross = gaussian_mean(y_true * y_pred)
    sq_sum = gaussian_mean(K.square(y_true) + K.square(y_pred))

    luminance = (2 * mu_prod + c1) / (mu_sq_sum + c1)
    contrast_structure = (2 * (cross - mu_prod) + c2) / (sq_sum - mu_sq_sum + c2)
    return K.mean(K.batch_flatten(luminance * contrast_structure), axis=-1)

def revert_img(img,original_size, patch_shape, slide):
  """Reassemble a grid of overlapping square patches into one square image.

  Patches are ``step = patch_shape - slide`` pixels apart. Each overlap is split
  in half between the two neighbouring patches; patches on the first and last
  row/column also keep their outer ``slide / 2`` margin.

  Args:
    img: array indexed ``[row, col, y, x]`` — a grid of patches.
    original_size: side length of the reconstructed image.
    patch_shape: side length of each patch.
    slide: overlap in pixels between neighbouring patches.

  Returns:
    float64 array of shape ``(original_size, original_size)``.
  """
  step = int(patch_shape-slide)
  half = int(slide/2)
  n_rows, n_cols = img.shape[0], img.shape[1]
  canvas = np.zeros((original_size,original_size))

  def span(idx, count):
    # (offset inside the patch, extra pixels kept beyond `step`)
    offset = 0 if idx == 0 else half
    extra = half if idx == 0 or idx == count - 1 else 0
    return offset, extra

  for row in range(n_rows):
    r_off, r_extra = span(row, n_rows)
    r_pos, r_len = row * step + r_off, step + r_extra
    for col in range(n_cols):
      c_off, c_extra = span(col, n_cols)
      c_pos, c_len = col * step + c_off, step + c_extra
      canvas[r_pos:r_pos + r_len, c_pos:c_pos + c_len] = \
          img[row, col, r_off:r_off + r_len, c_off:c_off + c_len]
  return canvas

def merge_patches(img, original_size, patch_shape, slide):
  """Stitch a flat, row-major stack of patches back into one image.

  ``img`` is a 3D array ``(n_patches, patch_shape, patch_shape)``. The first
  ``floor(sqrt(n_patches))**2`` patches are arranged into a square grid and
  merged with ``revert_img``.
  """
  per_row = int(math.sqrt(img.shape[0]))
  grid = np.zeros((per_row, per_row, patch_shape, patch_shape))
  grid[...] = np.reshape(img[:per_row * per_row], (per_row, per_row) + tuple(img.shape[1:]))
  return revert_img(grid, original_size, patch_shape, slide)

def normalize0to1(img):
  """Min-max rescale ``img`` so its minimum maps to 0 and its maximum to 1.

  The scaling is linear, so the shape of the intensity distribution is preserved.
  A constant image has no range to scale by. Instead of dividing by zero (which
  produces NaNs), it is mapped to all zeros.
  """
  lo = np.amin(img)
  span = np.amax(img) - lo
  shifted = img - lo
  if span == 0:
    # `* 0.0` yields zeros with the same dtype the division would have produced
    return shifted * 0.0
  return shifted / span

def mergeAndPredict(model,raw,gt,start, full_size, patch_size, overlap, merge=True, normalize=True, clip=False):
  """Predict the patches of one full image, optionally stitch them, and plot raw | predicted | gt.

  Args:
    model: object with a ``predict`` method (the Keras model in this notebook).
    raw: stack of input patches, 4D ``(n, h, w, depth)`` or 3D ``(n, h, w)``.
    gt: ground-truth patches with the same layout as ``raw``.
    start: index of the first patch that belongs to the image.
    full_size: side length of the full (square) image.
    patch_size: side length of each patch.
    overlap: overlap between neighbouring patches (``slide`` in ``patchify``).
    merge: if True, stitch the patches into full images with ``merge_patches``.
    normalize: if True, min-max rescale each output to [0, 1].
    clip: if True, convert each output to uint8 in [0, 255].

  Returns:
    List ``[raw, predicted, gt]``. For 4D input only the middle depth (index 1)
    of ``raw`` and ``gt`` is used. Previously 3D input failed with
    UnboundLocalError; it is now handled directly.
  """
  # number of patches covering one full image
  end = start+int((full_size - overlap) / (patch_size - overlap))**2
  pred = model.predict(raw[start:end])
  print("predicted results shape:  ",np.shape(pred))
  # pred is indexed from 0 because only raw[start:end] was predicted
  pred_2d = pred[0:end-start,:,:,0]
  if len(np.shape(raw)) > 3:
    # 3-depth stacks: compare against the middle depth only
    raw_2d = raw[start:end,:,:,1]
    gt_2d = gt[start:end,:,:,1]
  else:
    raw_2d = raw[start:end,:,:]
    gt_2d = gt[start:end,:,:]
  if merge:
    result = [merge_patches(m, full_size, patch_size, overlap) for m in (raw_2d, pred_2d, gt_2d)]
  else:
    result = [raw_2d, pred_2d, gt_2d]
  if normalize: result = [normalize0to1(m) for m in result]
  if clip: result = [np.clip(255 * m, 0, 255).astype('uint8') for m in result]
  plot(np.concatenate((np.concatenate((result[0],result[1]),axis=1), result[2]),axis=1), "raw, predicted, gt image", size=(30,30))
  return result

def evalSSIM(result):
  """Print this notebook's ``ssim`` and the RCAN ``ssim_rcan`` for every image pair.

  Args:
    result: sequence whose first three items are the raw, predicted and
      ground-truth images (as returned by ``mergeAndPredict``). Each image
      gets a trailing channel axis and a leading batch axis before scoring.
  """
  def as_batch(image):
    return tf.convert_to_tensor(np.expand_dims(np.expand_dims(image, -1), 0), dtype=np.float32)

  raw, rest, gt = as_batch(result[0]), as_batch(result[1]), as_batch(result[2])
  # Metric evaluation: (header, first name, second name, reference, compared image)
  pairs = (
      ('raw vs gt=======================', 'raw', 'ground truth', gt, raw),
      ('predicted vs gt=======================', 'predicted', 'ground truth', gt, rest),
      ('predicted vs raw=======================', 'predicted', 'raw', raw, rest),
  )
  for header, first, second, reference, compared in pairs:
    print(header)
    print(f"Our SSIM between {first} and {second}: ", ssim(reference, compared).numpy())
    print(f"RCAN paper SSIM between {first} and {second}: ", ssim_rcan(reference, compared).numpy())



def plot(img, label, hist=False, size=(10,10)):
  """Show ``img`` in a new figure titled ``label``.

  The image is drawn in grayscale, or as a 120-bin histogram of its values
  when ``hist`` is True.
  """
  plt.figure(figsize=size)
  show_histogram = hist == True  # explicit comparison kept on purpose
  if show_histogram:
    plt.hist(img.flatten(), bins=120)
  else:
    plt.imshow(img, cmap="gray")
  plt.title(label)




In [None]:
from numpy.lib.shape_base import expand_dims
### Experiment 5: Experiment with 3 depth stacks vs 2D

# The goal of this experiment is to investigate whether the model is able to learn the transformation when given information from the preceding and following depths.

# To decide: how the output will be determined — whether it will be a 3D stack or a single image (of the middle depth).
def mssim(y_true, y_pred):
   """Multi-scale SSIM (max value 1) with a 3x3 window, sigma 0.5 and k2 = 0.05.

   Redefines the earlier 11x11 / sigma 1.5 ``mssim`` for the 3-depth experiment;
   ``w`` supplies the per-scale power factors.
   """
   return tf.image.ssim_multiscale(
       y_true, y_pred, 1,
       power_factors=w,
       filter_size=3,
       filter_sigma=0.5,
       k2=0.05,
   )
# **Methods**: generate 3D patches of 3 consecutive depths.
def genSSIML1_3D_loss(alpha=0.84, model=None, l2_lambda=0.1):
  """Factory for an MS-SSIM + L1 loss with an L2 weight penalty, for 3-depth targets.

  ``y_true`` is a stack of 3 depths on its last axis (n, n, 3). Only the middle
  depth (index 1) is compared with the single-channel prediction.

  Args:
    alpha: weight of the MS-SSIM term; ``1 - alpha`` weights the L1 term.
    model: model whose trainable variables are penalised. If None, the global
      named ``model`` is looked up each time the loss runs (the previous
      implicit behaviour), so it must exist before training starts.
    l2_lambda: weight of the L2 penalty (sum of ``tf.nn.l2_loss`` over all
      trainable variables).

  Returns:
    A Keras-compatible ``loss(y_true, y_pred)`` function.
  """
  def SSIM_L1_loss(y_true, y_pred):
    # keep only the middle depth, restoring a trailing channel axis
    y_true = tf.expand_dims(y_true[..., 1], -1)
    ssim_partial = 1-((mssim(y_true, y_pred)+1)*0.5)
    mae_partial = tf.keras.losses.mae(
          *[tf.keras.backend.batch_flatten(y) for y in [y_true, y_pred]])

    # L2 regularizer over every trainable weight
    target_model = model if model is not None else globals()['model']
    l2_norm = tf.reduce_sum([tf.nn.l2_loss(v) for v in target_model.trainable_variables])
    return alpha*ssim_partial + (1-alpha)*mae_partial + l2_lambda*l2_norm
  return SSIM_L1_loss

def genSSIMVar_3D_loss(alpha=0.84):
  """Factory for an MS-SSIM + ``mov_var`` L1 loss on the middle depth of 3-depth targets.

  ``y_true`` has 3 depths on its last axis; only depth 1 is compared with the
  single-channel prediction.
  """
  def SSIMVar_loss(y_true, y_pred):
    middle = tf.expand_dims(y_true[..., 1], -1)
    flat_var = [tf.keras.backend.batch_flatten(mov_var(t)) for t in (middle, y_pred)]
    var_mae = tf.keras.losses.mae(*flat_var)
    ssim_term = 1-((mssim(middle, y_pred)+1)*0.5)
    return alpha * ssim_term + (1-alpha) * var_mae
  return SSIMVar_loss

# def genSSIMVar_loss(alpha=0.84):
#   def SSIMVar_loss(y_true, y_pred):
#       SSIM = 1-((ssim(y_true, y_pred)+1)*0.5)
#       MAE = tf.keras.losses.mae(
#           *[tf.keras.backend.batch_flatten(y) for y in [mov_var(y_true), mov_var(y_pred)]])
#       return alpha * SSIM + (1-alpha) * MAE
#   return SSIMVar_loss

def _middle_depth(y_true):
  # keep only the middle of the 3 stacked depths, restoring a channel axis
  return tf.expand_dims(y_true[..., 1], -1)

def ssim_3d_metric(y_true, y_pred):
  """SSIM between the middle depth of ``y_true`` and ``y_pred``."""
  return ssim(_middle_depth(y_true), y_pred)

def psnr_3d_metric(y_true, y_pred):
  """PSNR between the middle depth of ``y_true`` and ``y_pred``."""
  return psnr(_middle_depth(y_true), y_pred)

def mae_3d_metric(y_true, y_pred):
  """Mean absolute error (over the last axis) between the middle depth of ``y_true`` and ``y_pred``."""
  return tf.keras.losses.mae(_middle_depth(y_true), y_pred)

def config(patch_size, depth):
  """Hyper-parameters for the 3-depth experiment (overrides the earlier ``config``).

  Earlier trials differed from the active settings as follows:
    * learning_rate 1e-3, alpha 0.84, 100 epochs, lr_decay_factor 0.5,
      lr_decay_patience 3 (no kern_sigma, no loss entry);
    * learning_rate 1e-4, alpha 0.9, 200 epochs, lr_decay_factor 0.5,
      lr_decay_patience 10, loss genSSIML1_3D_loss.

  NOTE(review): the training cell builds genSSIML1_3D_loss directly and does
  not read the 'loss' entry below.
  """
  return dict(
      img_size=512,
      learning_rate=1e-4,
      batch_size=16,
      alpha=0.6,
      patch_size=patch_size,
      input_shape=[patch_size, patch_size],
      kern_size=3,
      kern_sigma=0.99,
      n_depth=depth,
      first_depth=32,
      dropout=0,
      epoch=200,
      lr_decay_factor=0.99,
      lr_decay_patience=5,
      loss=genSSIMVar_3D_loss,
  )

In [None]:
config_train = config(256, 6)

# Patches are cut up-front by load_data/patchify rather than sampled on the fly.
# On-the-fly alternative (unused):
# data_gen = DataGenerator(config_train['input_shape'], config_train['batch_size'],
#                          transform_function=None)
training_data, val_data = load_data(
    MAIN_PATH + DATA_PATH, patch=config_train['patch_size'])
training_data_labels, val_data_labels = load_data(
    MAIN_PATH + LABEL_PATH, patch=config_train['patch_size'])


      [PATCHIFYING COMPLETED] output shape, slide:  (2610, 256, 256, 3) 128 ; number of images:  290 , number of patches:  2610
      [PATCHIFYING COMPLETED] output shape, slide:  (378, 256, 256, 3) 128 ; number of images:  42 , number of patches:  378
      [PATCHIFYING COMPLETED] output shape, slide:  (2610, 256, 256, 3) 128 ; number of images:  290 , number of patches:  2610
      [PATCHIFYING COMPLETED] output shape, slide:  (378, 256, 256, 3) 128 ; number of images:  42 , number of patches:  378


In [None]:
# On-the-fly generators (unused):
# tdata = data_gen.flow(*list(zip([training_data[:1000],training_data_labels[:1000]])))
# vdata = data_gen.flow(*list(zip([val_data[:100],val_data_labels[:100]])))

# NOTE(review): config_train['loss'] is genSSIMVar_3D_loss, but this cell trains
# with genSSIML1_3D_loss — confirm which one is intended.
curr_loss = genSSIML1_3D_loss(alpha=config_train['alpha'])
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=config_train['lr_decay_factor'],
    patience=config_train['lr_decay_patience'],
    min_delta=0,
    verbose=True,
)

# TRAINING
# 3 stacked depths in, 1 channel (the middle depth) out.
model = build_model(
    n_dim=2,
    n_depth=config_train['n_depth'],
    kern_size=config_train['kern_size'],
    n_first=config_train['first_depth'],
    n_channel_out=1,
    batch_norm=False,
    residual=False,
    shape=(config_train['input_shape'][0], config_train['input_shape'][0], 3),
)
model = compile_model(
    model,
    config_train['learning_rate'],
    loss=curr_loss,
    metric=[{'psnr': psnr_3d_metric, 'ssim': ssim_3d_metric, 'mae': mae_3d_metric}],
)


CUSTOM UNET INPUT SHAPE:  (256, 256, 3)
n_dim 2
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input (InputLayer)             [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 down_level_0_no_0 (Conv2D)     (None, 256, 256, 32  896         ['input[0][0]']                  
                                )                                                                 
                                                                                                  
 down_level_0_no_1 (Conv2D)     (None, 256, 256, 32  9248        ['down_level_0_no_0[0][0]']      
                                )             

In [None]:
# Inputs and labels are 3-depth patch stacks; losses and metrics compare the
# prediction with the middle label depth.
history = model.fit(
    training_data,
    training_data_labels,
    epochs=config_train['epoch'],
    batch_size=config_train['batch_size'],
    validation_data=(val_data, val_data_labels),
    callbacks=[reduce_lr, tb_callback],
)


Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200
Epoch 50/200
Epoch 51/200
Epoch 52/200
Epoch 53/200
Epoch 54/200
Epoch 55/200
Epoch 56/200
Epoch 57/200
Epoch 58/200
Epoch 59/200
Epoch 60/200
Epoch 61/200
Epoch 62/200
Epoch 63/200
Epoch 63: ReduceLROnPlateau reducing learning rate to 9.899999749904965e-05.
Epoch 64/200
Epoch 65/200
Epoch 66/200
Epoch 67/200
Epoch 68/200
Epoch 68: ReduceLROnPlateau reducing learning ra

KeyboardInterrupt: ignored

In [None]:
# Evaluate SSIM on full images re-assembled from patches (overlap = half a patch):
# validation set first, then training set.
result = mergeAndPredict(model, val_data, val_data_labels, 0,
                         config_train['img_size'], config_train['patch_size'],
                         config_train['patch_size']/2)
evalSSIM(result)
result = mergeAndPredict(model, training_data, training_data_labels, 0,
                         config_train['img_size'], config_train['patch_size'],
                         config_train['patch_size']/2)
evalSSIM(result)
print("END TRIAL =========================================")
print()

In [None]:
from keras.callbacks import Callback, EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# LEARNING RATE SEARCH EXPERIMENT
class MetricsCheckpoint(Callback):
    """Keras callback that accumulates per-epoch metrics and saves them after each epoch.

    After every epoch, each entry of ``logs`` is appended to ``self.history[name]`` and
    the whole dict is written with ``np.save(self.savepath, self.history)``. The file is
    rewritten every epoch, so it always holds the complete history so far.

    Args:
        savepath: destination passed to ``np.save``.
    """

    def __init__(self, savepath):
        super().__init__()
        self.savepath = savepath
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        # logs defaults to None; treat that as "no metrics this epoch" instead of
        # crashing on None.items().
        for name, value in (logs or {}).items():
            self.history.setdefault(name, []).append(value)
        np.save(self.savepath, self.history)

def plotKerasLearningCurve(path='logs.npy', filt=('acc',)):
    """Plot train/validation learning curves from a metrics dict saved with np.save.

    Args:
        path: .npy file holding a dict {metric_name: [value per epoch, ...]}, such as
            the one MetricsCheckpoint writes. Defaults to 'logs.npy'.
        filt: substrings selecting which metrics to plot; a metric is plotted when its
            name contains any of them. Defaults to ('acc',); add 'loss' to also see the
            loss learning curve.

    Training curves are red and validation curves (name contains 'val') are blue. The
    best epoch (argmin for names containing 'loss', argmax otherwise) is marked with a
    dot and annotated as "epoch = value".
    """
    plt.figure(figsize=(10, 5))
    # np.save stores a dict as a pickled 0-d object array. NumPy (>= 1.16.3) refuses to
    # load that unless allow_pickle=True, so only load metric files you wrote yourself.
    metrics = np.load(path, allow_pickle=True).item()
    for k in metrics.keys():
        if not any(kk in k for kk in filt):
            continue
        l = np.array(metrics[k])
        color = 'r' if 'val' not in k else 'b'
        plt.plot(l, c=color, label='val' if 'val' in k else 'train')
        x = np.argmin(l) if 'loss' in k else np.argmax(l)
        y = l[x]
        plt.scatter(x, y, lw=0, alpha=0.25, s=100, c=color)
        plt.text(x, y, '{} = {:.4f}'.format(x, y), size='15', color=color)
    plt.legend(loc=4)
    plt.axis([0, None, None, None])
    plt.grid()
    plt.xlabel('Number of epochs')

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Plot a confusion matrix as a heat map and write each cell's value on top of it.

    Args:
        cm: square 2-D array; rows index the true label, columns the predicted label.
        classes: tick labels, one per row/column.
        normalize: if True, divide each row by its sum before plotting. Both the colours
            and the printed numbers then show per-true-class fractions (two decimals).
            Rows summing to 0 stay 0 instead of becoming NaN.
        title: figure title.
        cmap: matplotlib colormap used for the heat map.
    """
    cm = np.asarray(cm)
    if normalize:
        # Normalize BEFORE drawing, so the image, the colour threshold and the text
        # annotations all describe the same matrix.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    plt.figure(figsize=(5, 5))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    # White text on cells above half the maximum value, black text on the rest.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        value = cm[i, j]
        plt.text(j, i, format(value, '.2f') if normalize else value,
                 horizontalalignment="center",
                 color="white" if value > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

# Class index -> class-name string, for the ten digit classes '0'..'9'.
dict_characters = {digit: str(digit) for digit in range(10)}

from matplotlib import pyplot as plt
import math
from keras.callbacks import LambdaCallback
import keras.backend as K


class LRFinder:
    """
    Plots the change of the loss function of a Keras model when the learning rate is exponentially increasing.
    See for details:
    https://towardsdatascience.com/estimating-optimal-learning-rate-for-a-deep-neural-network-ce32f2556ce0

    ``find`` trains the model while multiplying the optimizer learning rate by a constant
    factor after every batch. It records each batch's learning rate and loss in
    ``self.lrs`` / ``self.losses``. Afterwards the weights (via a temporary 'tmp.h5' file)
    and the original learning rate are restored. ``plot_loss`` and ``plot_loss_change``
    visualize the recorded sweep.
    """
    def __init__(self, model):
        self.model = model
        self.losses = []
        self.lrs = []
        self.best_loss = 1e9

    @staticmethod
    def _trim(values, n_skip_beginning, n_skip_end):
        """Return `values` without its first n_skip_beginning and last n_skip_end items.

        Unlike ``values[a:-b]``, this keeps the tail when n_skip_end == 0 (``[a:-0]``
        selects nothing). It returns an empty sequence when more items are skipped than exist.
        """
        return values[n_skip_beginning:max(len(values) - n_skip_end, 0)]

    def on_batch_end(self, batch, logs):
        """Record the current lr and loss, stop on divergence, then grow the lr.

        Training stops when the loss is NaN or exceeds 4x the best loss seen so far.
        """
        # Log the learning rate
        lr = K.get_value(self.model.optimizer.lr)
        self.lrs.append(lr)

        # Log the loss
        loss = logs['loss']
        self.losses.append(loss)

        # Check whether the loss got too large or NaN
        if math.isnan(loss) or loss > self.best_loss * 4:
            self.model.stop_training = True
            return

        if loss < self.best_loss:
            self.best_loss = loss

        # Increase the learning rate for the next batch
        lr *= self.lr_mult
        K.set_value(self.model.optimizer.lr, lr)

    def find(self, x_train, y_train, start_lr, end_lr, batch_size=64, epochs=1):
        """Run the learning-rate sweep from start_lr towards end_lr.

        Args:
            x_train, y_train: training inputs/targets passed to model.fit.
            start_lr: learning rate used for the first batch.
            end_lr: learning rate the sweep grows towards over all batches.
            batch_size, epochs: forwarded to model.fit.

        The weights are saved to 'tmp.h5' first. They are reloaded at the end, together
        with the original learning rate, even if fit raises or is interrupted.
        """
        # N samples are split into ceil(N / batch_size) batches per epoch (the last may be
        # partial). Count whole batches so the lr growth is spread over the batches that run.
        num_batches = epochs * math.ceil(x_train.shape[0] / batch_size)
        self.lr_mult = (end_lr / start_lr) ** (1 / num_batches)

        # Save weights into a file
        self.model.save_weights('tmp.h5')

        # Remember the original learning rate
        original_lr = K.get_value(self.model.optimizer.lr)

        # Set the initial learning rate
        K.set_value(self.model.optimizer.lr, start_lr)

        callback = LambdaCallback(on_batch_end=lambda batch, logs: self.on_batch_end(batch, logs))

        try:
            self.model.fit(x_train, y_train,
                           batch_size=batch_size, epochs=epochs,
                           callbacks=[callback])
        finally:
            # Restore the weights and the learning rate to their pre-sweep state, even on
            # an exception or KeyboardInterrupt.
            self.model.load_weights('tmp.h5')
            K.set_value(self.model.optimizer.lr, original_lr)

    def plot_loss(self, n_skip_beginning=10, n_skip_end=5):
        """
        Plots the loss.
        Parameters:
            n_skip_beginning - number of batches to skip on the left.
            n_skip_end - number of batches to skip on the right (0 keeps them all).
        """
        plt.ylabel("loss")
        plt.xlabel("learning rate (log scale)")
        plt.plot(self._trim(self.lrs, n_skip_beginning, n_skip_end),
                 self._trim(self.losses, n_skip_beginning, n_skip_end))
        plt.xscale('log')
        # Dump the recorded values once, rather than once per recorded batch.
        if self.lrs:
            print('lrs, losses', self.lrs, self.losses)

    def plot_loss_change(self, sma=1, n_skip_beginning=10, n_skip_end=5, y_lim=(-0.01, 0.01)):
        """
        Plots rate of change of the loss function.
        Parameters:
            sma - number of batches for simple moving average to smooth out the curve.
            n_skip_beginning - number of batches to skip on the left.
            n_skip_end - number of batches to skip on the right (0 keeps them all).
            y_lim - limits for the y axis.
        """
        assert sma >= 1
        # derivatives[i] = (loss[i] - loss[i - sma]) / sma; the first sma entries are 0.
        # Cap the zero padding so there is one derivative per recorded lr.
        derivatives = [0] * min(sma, len(self.lrs))
        for i in range(sma, len(self.lrs)):
            derivative = (self.losses[i] - self.losses[i - sma]) / sma
            derivatives.append(derivative)

        plt.ylabel("rate of loss change")
        plt.xlabel("learning rate (log scale)")
        plt.plot(self._trim(self.lrs, n_skip_beginning, n_skip_end),
                 self._trim(derivatives, n_skip_beginning, n_skip_end))
        plt.xscale('log')
        plt.ylim(y_lim)
        # Dump the recorded values once, rather than once per recorded batch.
        if self.lrs:
            print('lrs, losses', self.lrs, self.losses)

In [None]:
# Sweep the learning rate exponentially from 1e-7 up to 100 over 5 epochs (batch size 16).
lr_finder = LRFinder(model)
lr_finder.find(
    training_data,
    training_data_labels,
    start_lr=1e-7,
    end_lr=100,
    batch_size=16,
    epochs=5,
)


In [None]:
# Loss vs. learning rate, leaving out the first 20 and the last 5 recorded batches.
lr_finder.plot_loss(n_skip_end=5, n_skip_beginning=20)
plt.show()

In [None]:
expanded_training_shape = np.shape(expandLastDim(training_data))
print(expanded_training_shape)

In [None]:
# Shapes of: training inputs, training labels, validation inputs, validation labels.
print(*(np.shape(arr) for arr in (training_data, training_data_labels, val_data, val_data_labels)))

In [None]:
# # print(np.amax(val_data[0]))
# plot(val_data[0],'test',hist=True)
# [training_dataa, val_dataa] = load_data(MAIN_PATH+DATA_PATH, patch=32)
# print("TEST AFTER", np.amax(training_dataa))
# test = normalize(merge_patches(val_dataa[0:961,:,:],512,32,16))
# plot(test,'test',hist=True)


In [None]:
# result = predictAndPlot(model,val_data,val_data_labels, 0, 961, isPlot=True, merge=True, normalize=False, clip=False)
# evalSSIM(result)

In [None]:
def revert_img(img,original_size, patch_shape, slide):
  """Stitch a grid of overlapping square patches back into one square image.

  Neighbouring patch origins are ``patch_shape - slide`` pixels apart, so adjacent
  patches share a ``slide``-wide overlap strip. Each strip is split between the two
  patches: the earlier patch supplies its first ``int(slide / 2)`` rows/columns and
  the later patch the rest. The first and last patch in each direction keep their
  outer edge in full, which also covers odd ``slide`` values and a single-patch grid.

  Args:
    img: array indexed img[row, col, i, j], i.e. a (rows, cols) grid of
      patch_shape x patch_shape patches.
    original_size: side length of the reconstructed image.
    patch_shape: side length of each patch.
    slide: overlap in pixels between neighbouring patches (0 = no overlap).

  Returns:
    (original_size, original_size) float64 array. Pixels not covered by any patch stay 0.
  """
  step = int(patch_shape - slide)   # distance between neighbouring patch origins
  lead = int(slide / 2)             # overlap trimmed from the start of non-first patches
  patch_len = int(patch_shape)
  last_row, last_col = img.shape[0] - 1, img.shape[1] - 1
  reconstructed_arr = np.zeros((original_size, original_size))
  for x in range(img.shape[0]):
    # Patch rows kept: [start_x, stop_x). The first patch keeps its top edge and the
    # last keeps its bottom edge (for odd slide that is slide - lead rows, not lead).
    start_x = 0 if x == 0 else lead
    stop_x = patch_len if x == last_row else lead + step
    for y in range(img.shape[1]):
      start_y = 0 if y == 0 else lead
      stop_y = patch_len if y == last_col else lead + step
      x_pos, y_pos = x * step + start_x, y * step + start_y
      reconstructed_arr[x_pos:x_pos + stop_x - start_x, y_pos:y_pos + stop_y - start_y] = \
          img[x, y, start_x:stop_x, start_y:stop_y]
  return reconstructed_arr

def merge_patches(img, original_size, patch_shape, slide):
  """Lay a flat stack of patches out on a square grid and stitch it with revert_img.

  Args:
    img: 3-D stack of patches indexed img[k, i, j]. Patch k goes to grid row
      k // row_len, column k % row_len, where row_len = int(sqrt(img.shape[0])).
      Any patches beyond row_len**2 are ignored.
    original_size, patch_shape, slide: forwarded to revert_img.

  Returns:
    The stitched image returned by revert_img.
  """
  print("merging patches, img shape: ", img.shape)
  row_len = int(math.sqrt(img.shape[0]))
  grid_count = row_len * row_len
  patches = np.zeros((row_len, row_len, patch_shape, patch_shape))
  print(img.shape)
  print(patches.shape)
  # Row-major reshape: grid cell [r, c] receives img[r * row_len + c].
  patches[...] = img[:grid_count].reshape((row_len, row_len) + img.shape[1:])
  return revert_img(patches, original_size, patch_shape, slide)


def MSE(img1, img2):
  """Mean squared error between two same-shaped images.

  The sum of squared differences is divided by img1.shape[0] * img1.shape[1],
  the pixel count of a 2-D image.
  """
  height, width = img1.shape[0], img1.shape[1]
  return np.sum(np.square(img1 - img2)) / (height * width)

def normalize(img):
  """Min-max scale an image to [0, 1], preserving the shape of its distribution.

  Returns (img - min) / (max - min). A constant image (max == min) has no range to
  scale by. Instead of dividing by zero (NaN output plus a RuntimeWarning), it maps
  to all zeros, matching normalize_between_zero_and_one.
  """
  img = np.asarray(img)
  lo, hi = np.amin(img), np.amax(img)
  if hi == lo:
    return np.zeros(img.shape, dtype=float)
  return (img - lo) / (hi - lo)

def plot(img, label, hist=False, size=(10,10)):
  """Open a new figure of the given size, draw img, and title it with label.

  With hist == True, the figure is a 120-bin histogram of all pixel values.
  Otherwise img is shown as a grayscale image.
  """
  plt.figure(figsize=size)
  if hist != True:
    plt.imshow(img, cmap="gray")
  else:
    plt.hist(img.flatten(), bins=120)
  plt.title(label)

def normalize_between_zero_and_one(m):
    """Min-max scale m to [0, 1]. A zero (or NaN) range yields np.zeros_like(m)."""
    lo = m.min()
    span = m.max() - lo
    if span > 0:
        return (m - lo) / span
    return np.zeros_like(m)

def normalize_percentile(img, pmin=0.1, pmax=99.9, clip = True):
  """Percentile-based contrast normalization.

  Maps the pmin-th percentile of img to 0 and the pmax-th percentile to 1. A tiny
  eps in the denominator avoids zero division. With clip == True the result is
  clipped to [0, 1]; otherwise out-of-range values are kept.
  """
  eps = 1e-20  # avoid zero division
  lower, upper = (np.percentile(img, q, axis=None, keepdims=True) for q in (pmin, pmax))
  scaled = (img - lower) / (upper - lower + eps)
  return np.clip(scaled, 0, 1) if clip == True else scaled

# def ssim(y_true, y_pred):
#     '''
#     Computes the structural similarity index between two images. Note that the
#     maximum signal value is assumed to be 1.
#     '''

#     return ssim2(y_true, y_pred,1,k2=0.05)
# our SSIM loss
from tensorflow.image import ssim as ssim2
def ssim_our(y_true, y_pred):
  """SSIM from tf.image.ssim with a maximum signal value of 1 and k2=0.05.

  All other tf.image.ssim parameters keep their library defaults.
  """
  return ssim2(y_true, y_pred, max_val=1, k2=0.05)

def _get_gaussian_kernel(dim, size, sigma):
    """Build a normalized `dim`-dimensional Gaussian kernel for convolution.

    Args:
        dim: number of spatial dimensions (ssim_rcan calls this with 2 or 3).
        size: side length of the kernel in every dimension.
        sigma: standard deviation of the Gaussian, in pixels.

    Returns:
        float32 tensor of shape ``(size,) * dim + (1, 1)``. The spatial entries sum
        to 1; the two trailing length-1 axes let it serve as a single-channel conv kernel.
    """
    k = size // 2
    offsets = tf.range(-k, size - k, dtype=tf.float32)
    # Unnormalized 1-D Gaussian. The usual 1/(sigma*sqrt(2*pi)) factor is omitted because
    # the kernel is renormalized to sum 1 below. Computing it with plain tf also removes the
    # dependency on tensorflow_probability (tfp), which this notebook imports only later.
    p = tf.exp(-0.5 * tf.square(offsets / sigma))

    # Outer product of `dim` copies of p, e.g. 'i,j->ij' in 2-D or 'i,j,k->ijk' in 3-D.
    indices = [chr(i) for i in range(105, 105 + dim)]
    eq = ','.join(indices) + '->' + ''.join(indices)
    kernel = tf.einsum(eq, *([p] * dim))
    kernel /= tf.reduce_sum(kernel)
    kernel = kernel[..., tf.newaxis, tf.newaxis]

    return kernel

# RCAN ssim
import keras.backend as K
def ssim_rcan(y_true, y_pred):
    '''
    Computes the structural similarity index between two images. Note that the
    maximum signal value is assumed to be 1.

    Inputs are expected to be batched and channels-last, with 2 or 3 spatial
    dimensions: K.ndim(...) of 4 or 5, where the leading batch axis and trailing
    channel axis are not counted as spatial. Local means/variances use an 11-wide
    Gaussian window (sigma 1.5), applied to each channel separately with the
    backend conv's default padding (NOTE(review): presumably 'valid', so the
    border is excluded from the SSIM map — confirm).

    Returns
    -------
    A tensor with one value per batch element: the SSIM map is flattened per
    sample (K.batch_flatten) and averaged over its last axis.

    Raises
    ------
    NotImplementedError
        If the number of spatial dimensions is neither 2 nor 3.

    References
    ----------
    Image Quality Assessment: From Error Visibility to Structural Similarity
    https://doi.org/10.1109/TIP.2003.819861
    '''

    # Stabilizing constants (0.01 * L)**2 and (0.03 * L)**2 with dynamic range L = 1.
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2

    # Spatial rank = total rank minus the batch and channel axes.
    dim = K.ndim(y_pred) - 2
    if dim not in (2, 3):
        raise NotImplementedError(f'{dim}D SSIM is not suported')

    # Needs a statically known channel count, because tf.split below requires it.
    num_channels = K.int_shape(y_pred)[-1]

    kernel = _get_gaussian_kernel(dim, 11, 1.5)
    conv = K.conv2d if dim == 2 else K.conv3d

    def average(x):
        # channel-wise weighted average using the Gaussian kernel
        return tf.concat(
            [conv(y, kernel) for y in tf.split(x, num_channels, axis=-1)],
            axis=-1)

    # Local (Gaussian-weighted) means of each image.
    ux = average(y_true)
    uy = average(y_pred)

    a = ux * uy                                      # mu_x * mu_y
    b = K.square(ux) + K.square(uy)                  # mu_x^2 + mu_y^2
    c = average(y_true * y_pred)                     # E[xy]; c - a is the covariance
    d = average(K.square(y_true) + K.square(y_pred)) # E[x^2 + y^2]; d - b = var_x + var_y

    # Luminance term and contrast-structure term of SSIM.
    lum = (2 * a + c1) / (b + c1)
    cs = (2 * (c - a) + c2) / (d - b + c2)

    return K.mean(K.batch_flatten(lum * cs), axis=-1)

def predictAndPlot(model,input,val, start, end, isPlot=True, merge=True, normalize=False, clip=False,
                   img_size=512, patch_size=64, slide=0):
  """Predict on input[start:end] and show input | prediction | ground truth side by side.

  Args:
    model: Keras model; model.predict is called on input[start:end] and its output
      is indexed [:, :, :, 0].
    input: model inputs, either a 3-D patch stack or a 4-D stack (channel 0 is used).
    val: ground-truth targets aligned with input, using the same layout.
    start, end: range of samples to predict and plot.
    isPlot: also draw a histogram of each of the three images.
    merge: stitch each patch stack into one image with
      merge_patches(..., img_size, patch_size, slide); otherwise use the stacks as-is.
    normalize: apply normalize_percentile to each of the three results.
    clip: convert each result to uint8 via np.clip(255 * x, 0, 255).
    img_size, patch_size, slide: stitching geometry. The defaults 512, 64, 0 are the
      values that were previously hard-coded.

  Returns:
    [input, prediction, ground truth] after the optional merge/normalize/clip steps.
  """
  pred = model.predict(input[start:end])
  print("shape of predicted image: ", np.shape(pred))
  # `pred` holds only the end - start predicted samples, so index it from 0. Re-slicing it
  # with [start:end] gave an empty or shifted stack whenever start > 0.
  pred_ch0 = pred[:, :, :, 0]
  if len(np.shape(input)) > 3:
    result = [input[start:end, :, :, 0], pred_ch0, val[start:end, :, :, 0]]
  else:
    result = [input[start:end, :, :], pred_ch0, val[start:end, :, :]]
  if merge:
    result = [merge_patches(m, img_size, patch_size, slide) for m in result]
  if normalize: result = [normalize_percentile(m) for m in result]
  if clip: result = [np.clip(255 * m, 0, 255).astype('uint8') for m in result]
  # input | restored | ground truth, concatenated along axis 1
  res_img = np.concatenate(result, axis=1)
  plot(res_img, "input, restored, gt", size=(30,30))
  if isPlot:
    plot(result[0],'hist input 25x',hist=True)
    plot(result[1],'hist restored',hist=True)
    plot(result[2],'hist gt 40x',hist=True)
  return result

def evalSSIM(result):
  """Print SSIM scores for the raw, predicted and ground-truth images in `result`.

  result is indexed as [raw input, prediction, ground truth], the order predictAndPlot
  returns. Each image gets a trailing channel axis and a leading batch axis and is
  converted to a float32 tensor. ssim_our and ssim_rcan are then printed for raw vs gt,
  predicted vs gt, and predicted vs raw.
  """
  def to_batched_tensor(image):
    with_channel = np.expand_dims(image, -1)
    return tf.convert_to_tensor(np.expand_dims(with_channel, 0), dtype=np.float32)

  raw = to_batched_tensor(result[0])
  rest = to_batched_tensor(result[1])
  gt = to_batched_tensor(result[2])

  # Metric evaluation: (header, our-SSIM label, RCAN label, reference, candidate)
  comparisons = [
      ('raw vs gt=======================',
       "Our SSIM between raw and ground truth: ",
       "RCAN paper SSIM between raw and ground truth: ",
       gt, raw),
      ('predicted vs gt=======================',
       "Our SSIM between predicted and ground truth: ",
       "RCAN paper SSIM between predicted and ground truth: ",
       gt, rest),
      ('predicted vs raw=======================',
       "Our SSIM between predicted and raw: ",
       "RCAN paper SSIM between predicted and raw: ",
       raw, rest),
  ]
  for header, our_label, rcan_label, reference, candidate in comparisons:
    print(header)
    print(our_label, ssim_our(reference, candidate).numpy())
    print(rcan_label, ssim_rcan(reference, candidate).numpy())

In [None]:
import tensorflow_probability as tfp
# Predict and plot the first 64 validation patches (stitched), then score them.
val_inputs = expandLastDim(val_data)
val_targets = expandLastDim(val_data_labels)
result = predictAndPlot(model, val_inputs, val_targets, 0, 64,
                        isPlot=True, merge=True, normalize=False, clip=False)
evalSSIM(result)

In [None]:
val_data_shape = np.shape(val_data)
print(val_data_shape)