# **DRMIME (2D)**

---

<font size = 4> DRMIME is a self-supervised deep-learning method that can be used to register 2D images.

<font size = 4> **This particular notebook enables self-supervised registration of 2D datasets.**

---

<font size = 4>*Disclaimer*:

<font size = 4>This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It is jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories. 


<font size = 4>While this notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (ZeroCostDL4Mic), its structure deviates substantially from other ZeroCostDL4Mic notebooks and from our template. This is because the deep-learning method employed here is used to improve the image registration process. No deep-learning models are saved, only the registered images. 


<font size = 4>This notebook is largely based on the following paper:

<font size = 4>DRMIME: Differentiable Mutual Information and Matrix Exponential for Multi-Resolution Image Registration by Abhishek Nan *et al.*, published on arXiv in 2020 (https://arxiv.org/abs/2001.09865)

<font size = 4>The source code is available at: https://github.com/abnan/DRMIME

<font size = 4>**Please also cite this original paper when using or developing this notebook.**


# **How to use this notebook?**

---

<font size = 4>Videos describing how to use our notebooks are available on YouTube:
  - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook
  - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook


---
###**Structure of a notebook**

<font size = 4>The notebook contains two types of cells:  

<font size = 4>**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.

<font size = 4>**Code cells** contain code, which can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]` mark on the left side of the cell (a play button appears) and click it. Once execution is done, the animation of the play button stops. You can create a new code cell by clicking `+ Code`.

---
###**Table of contents, Code snippets** and **Files**

<font size = 4>On the top left side of the notebook you will find three tabs which contain, from top to bottom:

<font size = 4>*Table of contents* = contains the structure of the notebook. Click an entry to move quickly between sections.

<font size = 4>*Code snippets* = contains examples of how to code certain tasks. You can ignore this tab when using this notebook.

<font size = 4>*Files* = contains all available files. After mounting your Google Drive (see section 1.2) you will find your files and folders here. 

<font size = 4>**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.

<font size = 4>**Note:** The "sample data" in "Files" contains default files. Do not upload anything in here!

---
###**Making changes to the notebook**

<font size = 4>**You can make a copy** of the notebook and save it to your Google Drive. To do this, click **File -> Save a copy in Drive**.

<font size = 4>To **edit a cell**, double-click on the text. This will show you either the source code (in code cells) or the source text (in text cells).
You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original piece of code in the cell as a comment.

# **0. Before getting started**
---

<font size = 4>Before you run the notebook, please ensure that you are logged into your Google account and have the data to process in your Google Drive.

<font size = 4>DRMIME requires at least two images: one **`Fixed image`** (template for the registration) and one **`Moving Image`** (image to be registered). Multiple **`Moving Images`** can also be provided if you want to register them to the same **`Fixed image`**. If you provide several **`Moving Images`**, multiple DRMIME instances will run one after another.   

<font size = 4>The registration can also be applied to other channels. If you wish to apply the registration to other channels, please provide the images in another folder and carefully check your file names. Additional channels need to have the same name as the registered images and a prefix indicating the channel number starting at "C1_". See the example below.   

<font size = 4>Here is a common data structure that can work:

*   Data
    
    - **Fixed_image_folder**
      - img_1.tif (image used as template for the registration)
    - **Moving_image_folder**
      - img_3.tif, img_4.tif, ... (images to be registered)   
    - **Folder_containing_additional_channels** (optional, if you want to apply the registration to other channel(s))
      - C1_img_3.tif, C1_img_4.tif, ...
      - C2_img_3.tif, C2_img_4.tif, ...
      - C3_img_3.tif, C3_img_4.tif, ...
    - **Results**

<font size = 4>The **Results** folder will contain the processed images and PDF reports. Your original images remain unmodified.
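
<font size = 4>The naming convention above can be checked before starting a run. The snippet below is a minimal sketch and is not part of DRMIME: `expected_channel_files` and `missing_channel_files` are hypothetical helper names, and the sketch assumes that every additional channel file is named `C<n>_` followed by the exact name of the corresponding moving image.

```python
def expected_channel_files(moving_names, n_channels):
    """Map each moving image name to the channel file names it should have.

    Hypothetical helper: assumes the "C<n>_" prefix convention, numbered from 1.
    """
    return {
        name: [f"C{c}_{name}" for c in range(1, n_channels + 1)]
        for name in moving_names
    }


def missing_channel_files(moving_names, channel_names, n_channels):
    """List the expected channel files that are absent from channel_names."""
    present = set(channel_names)
    expected = expected_channel_files(moving_names, n_channels)
    return [f for files in expected.values() for f in files if f not in present]


# Example lists; with real data they would come from os.listdir on each folder
moving = ["img_3.tif", "img_4.tif"]
channels = ["C1_img_3.tif", "C1_img_4.tif", "C2_img_3.tif"]
print(missing_channel_files(moving, channels, n_channels=2))
```

<font size = 4>An empty list means that every moving image has a matching file for each channel.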

---



# **1. Initialise the Colab session**




---






## **1.1. Check for GPU access**
---

<font size = 4>By default, the session should be using Python 3 and GPU acceleration, but you can check that these are set properly by doing the following:

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(Graphics processing unit)*
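
<font size = 4>The registration itself runs in PyTorch (the installation cell sets `device` from `torch.cuda.is_available()`), so GPU access can also be checked from PyTorch. The snippet below is a minimal sketch of such a check; `pick_device` is a hypothetical helper name, and the fallback to `"cpu"` when PyTorch is not installed is an assumption made so that the snippet runs anywhere.

```python
def pick_device():
    """Return 'cuda:0' when PyTorch sees a GPU, otherwise 'cpu'."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so nothing can be sent to a GPU from here
        return "cpu"
    return "cuda:0" if torch.cuda.is_available() else "cpu"


device_name = pick_device()
print("Registration would run on:", device_name)
```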


In [None]:
#@markdown ##Run this cell to check if you have GPU access
#%tensorflow_version 1.x


import tensorflow as tf
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime ?') 
  print('If the runtime setting is correct then Google did not allocate a GPU for your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')
  !nvidia-smi

## **1.2. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste it into the cell and press enter. This will give Colab access to the data on the drive. 

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of the notebook.

In [None]:
#@markdown ##Play the cell to connect your Google Drive to Colab

#@markdown * Click on the URL. 

#@markdown * Sign in to your Google Account. 

#@markdown * Copy the authorization code. 

#@markdown * Enter the authorization code. 

#@markdown * Click on the "Files" tab on the left and refresh it. Your Google Drive folder should now be available here as "gdrive". 

# mount user's Google Drive to Google Colab.
from google.colab import drive
drive.mount('/content/gdrive')

# **2. Install DRMIME and dependencies**
---

In [None]:
Notebook_version = ['1.12']



#@markdown ##Install DRMIME and dependencies


# Here we install DRMIME and other required packages

!pip install wget

from skimage import io
import numpy as np
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from skimage.transform import pyramid_gaussian
from skimage.filters import gaussian
from skimage.filters import threshold_otsu
from skimage.filters import sobel
from skimage.color import rgb2gray
from skimage import feature
from torch.autograd import Function
import cv2
from IPython.display import clear_output
import pandas as pd
from skimage.io import imsave



device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



# ------- Common variable to all ZeroCostDL4Mic notebooks -------

import urllib
import os, random
import shutil 
import zipfile
from tifffile import imread, imsave
import time
import sys
import wget
from pathlib import Path
import pandas as pd
import csv
from glob import glob
from scipy import signal
from scipy import ndimage
from skimage import io
from sklearn.linear_model import LinearRegression
from skimage.util import img_as_uint
import matplotlib as mpl
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as psnr
from astropy.visualization import simple_norm
from skimage import img_as_float32

# Colors for the warning messages
class bcolors:
  WARNING = '\033[31m'
W  = '\033[0m'  # white (normal)
R  = '\033[31m' # red

#Disable Python warnings
import warnings
warnings.filterwarnings("ignore")


# Check if this is the latest version of the notebook
Latest_notebook_version = pd.read_csv("https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv")

if Notebook_version == list(Latest_notebook_version.columns):
  print("This notebook is up-to-date.")
else:
  print(bcolors.WARNING +"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki"+W)

!pip freeze > requirements.txt

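#Create a pdf document with training summary. Not yet implemented for DRMIME:
#the function below still describes a CARE 2D model and uses names
#(FPDF, HTMLMixin, datetime, freeze, subprocess) that this cell does not import.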

def pdf_export(trained = False, augmentation = False, pretrained_model = False):
    # save FPDF() class into a  
    # variable pdf 
    #from datetime import datetime

    class MyFPDF(FPDF, HTMLMixin):
        pass

    pdf = MyFPDF()
    pdf.add_page()
    pdf.set_right_margin(-1)
    pdf.set_font("Arial", size = 11, style='B') 

    Network = 'CARE 2D'
    day = datetime.now()
    datetime_str = str(day)[0:10]

    Header = 'Training report for '+Network+' model ('+model_name+')\nDate: '+datetime_str
    pdf.multi_cell(180, 5, txt = Header, align = 'L') 

    # add another cell 
    if trained:
      training_time = "Training time: "+str(hour)+ "hour(s) "+str(mins)+"min(s) "+str(round(sec))+"sec(s)"
      pdf.cell(190, 5, txt = training_time, ln = 1, align='L')
    pdf.ln(1)

    Header_2 = 'Information for your materials and methods:'
    pdf.cell(190, 5, txt=Header_2, ln=1, align='L')

    all_packages = ''
    for requirement in freeze(local_only=True):
      all_packages = all_packages+requirement+', '
    #print(all_packages)

    #Main Packages
    main_packages = ''
    version_numbers = []
    for name in ['tensorflow','numpy','Keras','csbdeep']:
      find_name=all_packages.find(name)
      main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '
      #Version numbers only here:
      version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])

    cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)
    cuda_version = cuda_version.stdout.decode('utf-8')
    cuda_version = cuda_version[cuda_version.find(', V')+3:-1]
    gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)
    gpu_name = gpu_name.stdout.decode('utf-8')
    gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]
    #print(cuda_version[cuda_version.find(', V')+3:-1])
    #print(gpu_name)

    shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape
    dataset_size = len(os.listdir(Training_source))

    text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'

    if pretrained_model:
      text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'

    pdf.set_font('')
    pdf.set_font_size(10.)
    pdf.multi_cell(190, 5, txt = text, align='L')
    pdf.set_font('')
    pdf.set_font('Arial', size = 10, style = 'B')
    pdf.ln(1)
    pdf.cell(28, 5, txt='Augmentation: ', ln=0)
    pdf.set_font('')
    if augmentation:
      aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'
      if rotate_270_degrees != 0 or rotate_90_degrees != 0:
        aug_text = aug_text+'\n- rotation'
      if flip_left_right != 0 or flip_top_bottom != 0:
        aug_text = aug_text+'\n- flipping'
      if random_zoom_magnification != 0:
        aug_text = aug_text+'\n- random zoom magnification'
      if random_distortion != 0:
        aug_text = aug_text+'\n- random distortion'
      if image_shear != 0:
        aug_text = aug_text+'\n- image shearing'
      if skew_image != 0:
        aug_text = aug_text+'\n- image skewing'
    else:
      aug_text = 'No augmentation was used for training.'
    pdf.multi_cell(190, 5, txt=aug_text, align='L')
    pdf.set_font('Arial', size = 11, style = 'B')
    pdf.ln(1)
    pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)
    pdf.set_font('')
    pdf.set_font_size(10.)
    if Use_Default_Advanced_Parameters:
      pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')
    pdf.cell(200, 5, txt='The following parameters were used for training:')
    pdf.ln(1)
    html = """ 
    <table width=40% style="margin-left:0px;">
      <tr>
        <th width = 50% align="left">Parameter</th>
        <th width = 50% align="left">Value</th>
      </tr>
      <tr>
        <td width = 50%>number_of_epochs</td>
        <td width = 50%>{0}</td>
      </tr>
      <tr>
        <td width = 50%>patch_size</td>
        <td width = 50%>{1}</td>
      </tr>
      <tr>
        <td width = 50%>number_of_patches</td>
        <td width = 50%>{2}</td>
      </tr>
      <tr>
        <td width = 50%>batch_size</td>
        <td width = 50%>{3}</td>
      </tr>
      <tr>
        <td width = 50%>number_of_steps</td>
        <td width = 50%>{4}</td>
      </tr>
      <tr>
        <td width = 50%>percentage_validation</td>
        <td width = 50%>{5}</td>
      </tr>
      <tr>
        <td width = 50%>initial_learning_rate</td>
        <td width = 50%>{6}</td>
      </tr>
    </table>
    """.format(number_of_epochs,str(patch_size)+'x'+str(patch_size),number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)
    pdf.write_html(html)

    #pdf.multi_cell(190, 5, txt = text_2, align='L')
    pdf.set_font("Arial", size = 11, style='B')
    pdf.ln(1)
    pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)
    pdf.set_font('')
    pdf.set_font('Arial', size = 10, style = 'B')
    pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)
    pdf.set_font('')
    pdf.multi_cell(170, 5, txt = Training_source, align = 'L')
    pdf.set_font('')
    pdf.set_font('Arial', size = 10, style = 'B')
    pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)
    pdf.set_font('')
    pdf.multi_cell(170, 5, txt = Training_target, align = 'L')
    #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)
    pdf.ln(1)
    pdf.set_font('')
    pdf.set_font('Arial', size = 10, style = 'B')
    pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)
    pdf.set_font('')
    pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')
    pdf.ln(1)
    pdf.cell(60, 5, txt = 'Example Training pair', ln=1)
    pdf.ln(1)
    exp_size = io.imread('/content/TrainingDataExample_CARE2D.png').shape
    pdf.image('/content/TrainingDataExample_CARE2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))
    pdf.ln(1)
    ref_1 = 'References:\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy." BioRxiv (2020).'
    pdf.multi_cell(190, 5, txt = ref_1, align='L')
    ref_2 = '- CARE: Weigert, Martin, et al. "Content-aware image restoration: pushing the limits of fluorescence microscopy." Nature methods 15.12 (2018): 1090-1097.'
    pdf.multi_cell(190, 5, txt = ref_2, align='L')
    if augmentation:
      ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. "Augmentor: an image augmentation library for machine learning." arXiv preprint arXiv:1708.04680 (2017).'
      pdf.multi_cell(190, 5, txt = ref_3, align='L')
    pdf.ln(3)
    reminder = 'Important:\nRemember to perform the quality control step on all newly trained models\nPlease consider depositing your training dataset on Zenodo'
    pdf.set_font('Arial', size = 11, style='B')
    pdf.multi_cell(190, 5, txt=reminder, align='C')

    pdf.output(model_path+'/'+model_name+'/'+model_name+"_training_report.pdf")




print("Libraries installed")


# **3. Select your parameters and paths**
---

## **3.1. Setting main training parameters**
---
<font size = 4> **Paths for training, predictions and results**

<font size = 4>These are the paths to the folders containing the images you want to register. To find the path of a folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files, right-click on the folder, select **Copy path** and paste the path into the corresponding box below.

<font size = 4>**`Fixed_image_folder`:** This is the folder containing your "Fixed image".

<font size = 4>**`Moving_image_folder`:** This is the folder containing your "Moving Image(s)".

<font size = 4>**`Result_folder`:** This is the folder where your results will be saved.
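
<font size = 4>The code cell below accepts either a single file or a folder for `Fixed_image_folder`; for a folder it reads the first entry returned by `os.listdir`, whose order is arbitrary. The snippet below is a minimal sketch of a repeatable alternative and is not the notebook's own code: `resolve_fixed_image` is a hypothetical helper, and picking the first file in sorted order is an assumption about what "the" fixed image should be when the folder holds several files.

```python
import os
import tempfile


def resolve_fixed_image(path):
    """Return the image file to use as the fixed image.

    Hypothetical helper: accepts a file path or a folder, and in the folder
    case picks the first file in sorted order so the choice is repeatable.
    """
    if os.path.isfile(path):
        return path
    if os.path.isdir(path):
        files = sorted(
            f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))
        )
        if not files:
            raise FileNotFoundError(f"No files found in {path!r}")
        return os.path.join(path, files[0])
    raise FileNotFoundError(f"{path!r} is neither a file nor a folder")


# Demonstration on a temporary folder standing in for Fixed_image_folder
with tempfile.TemporaryDirectory() as folder:
    for name in ("img_2.tif", "img_1.tif"):
        open(os.path.join(folder, name), "w").close()
    chosen = os.path.basename(resolve_fixed_image(folder))
    print(chosen)
```

<font size = 4>Keeping a single image in the fixed image folder avoids the ambiguity altogether.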


<font size = 5>**Training Parameters**

<font size = 4>**`model_name`:** Choose a name for your model.

<font size = 4>**`number_of_iteration`:** Input how many iterations (rounds) the network will be trained for. Preliminary results can already be observed after 200 iterations, but a full training should run for 500-1000 iterations. **Default value: 500**

<font size = 4>**`Registration_mode`:** Choose which registration method you would like to use: `Affine` or `Perspective`.

<font size = 5>**Additional channels**

<font size = 4> This option enables you to apply the registration to other images (for instance other channels). Place these images in the **`Additional_channels_folder`**. Additional channels need to have the same name as the images you want to register (found in **`Moving_image_folder`**) and a prefix indicating the channel number starting at "C1_".

    
<font size = 5>**Advanced Parameters - experienced users only**

<font size = 4>**`n_neurons`:**  Number of neurons (elementary constituents) that will assemble your model. **Default value: 100**.

<font size = 4>**`mine_initial_learning_rate`:** Input the initial value to be used as learning rate for MINE. **Default value: 0.001**

<font size = 4>**`homography_net_vL_initial_learning_rate`:** Input the initial value to be used as learning rate for homography_net_vL. **Default value: 0.001**

<font size = 4>**`homography_net_v1_initial_learning_rate`:** Input the initial value to be used as learning rate for homography_net_v1. **Default value: 0.0001**
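
<font size = 4>MINE (https://arxiv.org/abs/1801.04062) is the small network that scores how well the two images agree. In the training cell of section 4.1 it has two hidden layers of `n_neurons` units and returns `torch.mean(z1) - torch.log(torch.mean(torch.exp(z2)))`, where `z1` is computed on matching pixel pairs and `z2` on pairs whose moving-image pixels were shuffled. The snippet below is a minimal pure-Python sketch of that quantity, for illustration only: `mine_lower_bound` is a hypothetical name, and the hand-written `critic` stands in for the trained network.

```python
import math


def mine_lower_bound(critic, xs, ys, ys_shuffled):
    """Mean critic score on matching pairs minus log-mean-exp on shuffled pairs."""
    joint = sum(critic(x, y) for x, y in zip(xs, ys)) / len(xs)
    marginal = sum(math.exp(critic(x, y)) for x, y in zip(xs, ys_shuffled)) / len(xs)
    return joint - math.log(marginal)


def critic(x, y):
    # Hand-written stand-in for the trained network: highest when x and y agree
    return -(x - y) ** 2


fixed = [0.0, 1.0, 2.0, 3.0]      # intensities at sampled pixels of the fixed image
aligned = [0.0, 1.0, 2.0, 3.0]    # same pixels in a perfectly registered moving image
shuffled = aligned[::-1]          # pairing broken by permuting the pixels

bound_informative = mine_lower_bound(critic, fixed, aligned, shuffled)
bound_constant = mine_lower_bound(lambda x, y: 0.0, fixed, aligned, shuffled)
print(bound_informative, bound_constant)
```

<font size = 4>A critic that ignores its inputs yields a bound of zero, while a critic that recognises matching pairs yields a positive value. Training raises this value, which is why the registration loss is its negative.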


In [None]:

#@markdown ###Path to the Fixed and Moving image folders: 
Fixed_image_folder = "" #@param {type:"string"}


import os.path
from os import path

if path.isfile(Fixed_image_folder):
  I = imread(Fixed_image_folder).astype(np.float32) # fixed image

if path.isdir(Fixed_image_folder):
  Fixed_image = os.listdir(Fixed_image_folder)
  I = imread(Fixed_image_folder+"/"+Fixed_image[0]).astype(np.float32) # fixed image


Moving_image_folder = "" #@param {type:"string"}

#@markdown ### Provide the path to the folder where the predictions are to be saved
Result_folder = "" #@param {type:"string"}


#@markdown ###Training Parameters
model_name = "" #@param {type:"string"}

number_of_iteration =  500#@param {type:"number"}

Registration_mode = "Affine" #@param ["Affine", "Perspective"]


#@markdown ###Do you want to apply the registration to other channel(s)?
Apply_registration_to_other_channels = False#@param {type:"boolean"}

Additional_channels_folder = "" #@param {type:"string"}

#@markdown ###Advanced Parameters

Use_Default_Advanced_Parameters = True#@param {type:"boolean"}

#@markdown ###If not, please input:

n_neurons = 100 #@param {type:"number"}
mine_initial_learning_rate = 0.001 #@param {type:"number"}
homography_net_vL_initial_learning_rate = 0.001 #@param {type:"number"}
homography_net_v1_initial_learning_rate = 0.0001 #@param {type:"number"}

if (Use_Default_Advanced_Parameters): 
  print("Default advanced parameters enabled")  
  n_neurons = 100
  mine_initial_learning_rate = 0.001
  homography_net_vL_initial_learning_rate = 0.001
  homography_net_v1_initial_learning_rate = 0.0001


#failsafe for downscale could be useful  
#to be added


#Load a random moving image to visualise and test the settings
random_choice = random.choice(os.listdir(Moving_image_folder))
J = imread(Moving_image_folder+"/"+random_choice).astype(np.float32)

# Check if additional channel(s) need to be registered and if so how many

print(str(len(os.listdir(Moving_image_folder)))+" image(s) will be registered.")

if Apply_registration_to_other_channels:

  other_channel_images = os.listdir(Additional_channels_folder)
  Number_of_other_channels = len(other_channel_images)/len(os.listdir(Moving_image_folder))

  if Number_of_other_channels.is_integer():
    print("The registration(s) will be propagated to "+str(int(Number_of_other_channels))+" other channel(s)")
  else:
    print(bcolors.WARNING +"!! WARNING: Incorrect number of images in Additional_channels_folder"+W)

#here we check that no results folder with the same name already exists, if so print a warning
if os.path.exists(Result_folder+'/'+model_name):
  print(bcolors.WARNING +"!! WARNING: "+model_name+" already exists and will be deleted in section 4.1 !!")
  print(bcolors.WARNING +"To keep the existing results, choose a new model_name here."+W)
  

print("Example of two images to be registered")

#Here we display one image
f=plt.figure(figsize=(10,10))
plt.subplot(1,2,1)
plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest')


plt.title('Fixed image')
plt.axis('off');

plt.subplot(1,2,2)
plt.imshow(J, norm=simple_norm(J, percent = 99), interpolation='nearest')
plt.title('Moving image')
plt.axis('off');
plt.savefig('/content/TrainingDataExample_DRMIME2D.png',bbox_inches='tight',pad_inches=0)
plt.show()



## **3.2. Choose and test the image pre-processing settings**
---
<font size = 4> DRMIME makes use of multi-resolution image pyramids to perform registration. Unlike a conventional method where computation starts at the highest level of the image pyramid and gradually proceeds to the lower levels, DRMIME simultaneously uses all the levels in gradient descent-based optimization using automatic differentiation. Here, you can choose the parameters that define the multi-resolution image pyramids that will be used.

<font size = 4>**`nb_images_pyramid`:** Choose the number of images used to assemble the pyramid. **Default value: 10**.

<font size = 4>**`Level_downscaling`:** Choose the downscaling factor applied between successive images of the pyramid. **Default value: 1.8**.

<font size = 4>**`sampling`:** Fraction of pixels sampled at each pyramid level for the perspective registration. **Default value: 0.1**.
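
<font size = 4>The two pyramid parameters interact: a larger `Level_downscaling` shrinks the image faster, so fewer levels are available for `nb_images_pyramid`. The snippet below is a minimal sketch for estimating how many levels an image can provide. `pyramid_shapes` is a hypothetical helper, and it assumes that each level has `ceil(previous size / downscale)` pixels along each axis and that the pyramid stops once a further reduction no longer changes the shape. This is an assumption about how `pyramid_gaussian` sizes its levels and should be checked against the installed scikit-image version.

```python
import math


def pyramid_shapes(shape, downscale):
    """Predict the (height, width) of each pyramid level, finest level first.

    Assumption: each level is ceil(previous / downscale) along both axes and
    the pyramid stops once a further reduction would not change the shape.
    """
    if downscale <= 1:
        raise ValueError("downscale must be greater than 1")
    shapes = [tuple(shape)]
    while True:
        nxt = tuple(math.ceil(d / float(downscale)) for d in shapes[-1])
        if nxt == shapes[-1]:
            return shapes
        shapes.append(nxt)


levels = pyramid_shapes((512, 512), downscale=1.8)
nb_images_pyramid = 10
print(len(levels), "levels available:", levels)
if nb_images_pyramid > len(levels):
    print("nb_images_pyramid is larger than the number of levels available")
```

<font size = 4>If the requested number of images exceeds the number of levels, reduce `nb_images_pyramid` or choose a smaller `Level_downscaling`.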



In [None]:

#@markdown ##Image pre-processing settings

nb_images_pyramid = 10#@param {type:"number"}  # where registration starts (at the coarsest resolution)

L = nb_images_pyramid

Level_downscaling = 1.8#@param {type:"number"}

downscale = Level_downscaling

sampling = 0.1#@param {type:"number"} # 10% sampling used only for perspective registration


ifplot=True
if np.ndim(I) == 3:
    nChannel=I.shape[2]
    pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))
    pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))
elif np.ndim(I) == 2:
    nChannel=1
    pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))
    pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))
else:
    print("Unknown rank for an image")


# Control the display
width=5
height=5
rows = int(L/5)+1
cols = 5
axes=[]
fig=plt.figure(figsize=(16,16))

if Registration_mode == "Affine":

  print("Affine registration selected")

# create a list of necessary objects you will need and commit to GPU
  I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]
  for s in range(L):
      I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)
      J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)

      if nChannel>1:
          I_lst.append(I_.permute(2,0,1))
          J_lst.append(J_.permute(2,0,1))
          h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]

          edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),
                                  np.ones((5,5),np.uint8),
                                  iterations = 1)
          ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]
          ind_lst.append(ind_)
      else:
          I_lst.append(I_)
          J_lst.append(J_)
          h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]

          edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),
                                  np.ones((5,5),np.uint8),
                                  iterations = 1)
          ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]
          ind_lst.append(ind_)  
        
      axes.append( fig.add_subplot(rows, cols, s+1) )
      subplot_title=(str(s))
      axes[-1].set_title(subplot_title)  
      plt.imshow(edges_grayscale)
      plt.axis('off');

      h_lst.append(h_)
      w_lst.append(w_)

      y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])
      y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0
      xy_ = torch.stack([x_,y_],2)
      xy_lst.append(xy_)

  fig.tight_layout()

  plt.show()


if Registration_mode == "Perspective":

  print("Perspective registration selected")

# create a list of necessary objects you will need and commit to GPU
  I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]
  for s in range(L):
      I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)
      J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)
    
      if nChannel>1:
          I_lst.append(I_.permute(2,0,1))
          J_lst.append(J_.permute(2,0,1))
          h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]
      else:
          I_lst.append(I_)
          J_lst.append(J_)
          h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]

      # The edge map is only displayed below; it is computed for every image type
      # so that the plot also works for multichannel images
      edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 10),
                              np.ones((5,5),np.uint8),
                              iterations = 1)

      # Perspective mode samples a fraction of the pixels instead of using edges
      ind_ = torch.randperm(int(h_*w_*sampling))
      ind_lst.append(ind_)
        
      axes.append( fig.add_subplot(rows, cols, s+1) )
      subplot_title=(str(s))
      axes[-1].set_title(subplot_title)  
      plt.imshow(edges_grayscale)
      plt.axis('off');

      h_lst.append(h_)
      w_lst.append(w_)

      y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])
      y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0
      xy_ = torch.stack([x_,y_],2)
      xy_lst.append(xy_)

  fig.tight_layout()

  plt.show()


# **4. Train the network**
---

## **4.1. Prepare for training**
---
<font size = 4>Here, we use the information from section 3 to load the correct dependencies.
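
<font size = 4>The `HomographyNet` defined in the cell below does not learn the 3x3 transformation matrix directly. It learns coefficients of basis matrices, sums them into a matrix `C`, and approximates the matrix exponential of `C` with a truncated power series (the identity plus nine terms). The snippet below is a minimal pure-Python sketch of that series, written for illustration only: `matmul3` and `expm_series` are hypothetical helper names, and the notebook itself works on PyTorch tensors.

```python
import math


def matmul3(a, b):
    """Multiply two 3x3 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]


def expm_series(C, n_terms=9):
    """Approximate exp(C) as I + C + C^2/2! + ... + C^n_terms/n_terms!."""
    H = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
    A = [row[:] for row in H]  # running term C^i / i!, starting from the identity
    for i in range(1, n_terms + 1):
        A = matmul3([[v / i for v in row] for row in A], C)
        H = [[H[r][c] + A[r][c] for c in range(3)] for r in range(3)]
    return H


# Translation generator: its square is zero, so the series is exact
T = expm_series([[0, 0, 2.5], [0, 0, 0], [0, 0, 0]])
print(T)

# Scaling generator: the diagonal approaches exp(0.5) and exp(-0.5)
S = expm_series([[0.5, 0, 0], [0, -0.5, 0], [0, 0, 0]])
print(S[0][0], math.exp(0.5))
```

<font size = 4>With all coefficients at zero the series returns the identity, so training starts from "no transformation" and moves away from it smoothly.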

In [None]:
#@markdown ##Load the dependencies required for training

print("--------------------------------------------------")

# Remove the model name folder if exists

if os.path.exists(Result_folder+'/'+model_name):
  print(bcolors.WARNING +"!! WARNING: Model folder already exists and has been removed !!"+W)
  shutil.rmtree(Result_folder+'/'+model_name)
os.makedirs(Result_folder+'/'+model_name)



if Registration_mode == "Affine":

  class HomographyNet(nn.Module):
      def __init__(self):
          super(HomographyNet, self).__init__()
          # affine transform basis matrices

          self.B = torch.zeros(6,3,3).to(device)
          self.B[0,0,2] = 1.0
          self.B[1,1,2] = 1.0
          self.B[2,0,1] = 1.0
          self.B[3,1,0] = 1.0
          self.B[4,0,0], self.B[4,1,1] = 1.0, -1.0
          self.B[5,1,1], self.B[5,2,2] = -1.0, 1.0

          self.v1 = torch.nn.Parameter(torch.zeros(6,1,1).to(device), requires_grad=True)
          self.vL = torch.nn.Parameter(torch.zeros(6,1,1).to(device), requires_grad=True)

      def forward(self, s):
          C = torch.sum(self.B*self.vL,0)
          if s==0:
              C += torch.sum(self.B*self.v1,0)
          A = torch.eye(3).to(device)
          H = A
          for i in torch.arange(1,10):
              A = torch.mm(A/i,C)
              H = H + A
          return H

  class MINE(nn.Module): #https://arxiv.org/abs/1801.04062
      def __init__(self):
          super(MINE, self).__init__()
          self.fc1 = nn.Linear(2*nChannel, n_neurons)
          self.fc2 = nn.Linear(n_neurons, n_neurons)
          self.fc3 = nn.Linear(n_neurons, 1)
          self.bsize = 1 # 1 may be sufficient

      def forward(self, x, ind):
          x = x.view(x.size()[0]*x.size()[1],x.size()[2])
          MI_lb=0.0
          for i in range(self.bsize):
              ind_perm = ind[torch.randperm(len(ind))]
              z1 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(x[ind,:])))))
              z2 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(torch.cat((x[ind,0:nChannel],x[ind_perm,nChannel:2*nChannel]),1))))))
              MI_lb += torch.mean(z1) - torch.log(torch.mean(torch.exp(z2)))

          return MI_lb/self.bsize

  def AffineTransform(I, H, xv, yv):
    # apply affine transform
      xvt = (xv*H[0,0]+yv*H[0,1]+H[0,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])
      yvt = (xv*H[1,0]+yv*H[1,1]+H[1,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])
      J = F.grid_sample(I,torch.stack([xvt,yvt],2).unsqueeze(0)).squeeze()
      return J


  def multi_resolution_loss():
      loss=0.0
      for s in np.arange(L-1,-1,-1):
          if nChannel>1:
              Jw_ = AffineTransform(J_lst[s].unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()
              mi = mine_net(torch.cat([I_lst[s],Jw_],0).permute(1,2,0),ind_lst[s])
              loss = loss - (1./L)*mi
          else:
              Jw_ = AffineTransform(J_lst[s].unsqueeze(0).unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()
              mi = mine_net(torch.stack([I_lst[s],Jw_],2),ind_lst[s])
              loss = loss - (1./L)*mi

      return loss



if Registration_mode == "Perspective":

  class HomographyNet(nn.Module):
      def __init__(self):
          super(HomographyNet, self).__init__()
          # homography (perspective) transform basis matrices: the six affine generators plus two projective ones

          self.B = torch.zeros(8,3,3).to(device)
          self.B[0,0,2] = 1.0
          self.B[1,1,2] = 1.0
          self.B[2,0,1] = 1.0
          self.B[3,1,0] = 1.0
          self.B[4,0,0], self.B[4,1,1] = 1.0, -1.0
          self.B[5,1,1], self.B[5,2,2] = -1.0, 1.0
          self.B[6,2,0] = 1.0
          self.B[7,2,1] = 1.0

          self.v1 = torch.nn.Parameter(torch.zeros(8,1,1).to(device), requires_grad=True)
          self.vL = torch.nn.Parameter(torch.zeros(8,1,1).to(device), requires_grad=True)

      def forward(self, s):
          C = torch.sum(self.B*self.vL,0)
          if s==0:
              C += torch.sum(self.B*self.v1,0)
          A = torch.eye(3).to(device)
          H = A
          for i in torch.arange(1,10):
              A = torch.mm(A/i,C)
              H = H + A
          return H


  class MINE(nn.Module): #https://arxiv.org/abs/1801.04062
      def __init__(self):
          super(MINE, self).__init__()
          self.fc1 = nn.Linear(2*nChannel, n_neurons)
          self.fc2 = nn.Linear(n_neurons, n_neurons)
          self.fc3 = nn.Linear(n_neurons, 1)
          self.bsize = 1 # 1 may be sufficient

      def forward(self, x, ind):
          x = x.view(x.size()[0]*x.size()[1],x.size()[2])
          MI_lb=0.0
          for i in range(self.bsize):
              ind_perm = ind[torch.randperm(len(ind))]
              z1 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(x[ind,:])))))
              z2 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(torch.cat((x[ind,0:nChannel],x[ind_perm,nChannel:2*nChannel]),1))))))
              MI_lb += torch.mean(z1) - torch.log(torch.mean(torch.exp(z2)))

          return MI_lb/self.bsize


  def PerspectiveTransform(I, H, xv, yv):
    # apply homography
      xvt = (xv*H[0,0]+yv*H[0,1]+H[0,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])
      yvt = (xv*H[1,0]+yv*H[1,1]+H[1,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])
      J = F.grid_sample(I,torch.stack([xvt,yvt],2).unsqueeze(0)).squeeze()
      return J


  def multi_resolution_loss():
      loss=0.0
      for s in np.arange(L-1,-1,-1):
          if nChannel>1:
              Jw_ = PerspectiveTransform(J_lst[s].unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()
              mi = mine_net(torch.cat([I_lst[s],Jw_],0).permute(1,2,0),ind_lst[s])
              loss = loss - (1./L)*mi
          else:
              Jw_ = PerspectiveTransform(J_lst[s].unsqueeze(0).unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()
              mi = mine_net(torch.stack([I_lst[s],Jw_],2),ind_lst[s])
              loss = loss - (1./L)*mi

      return loss

  def histogram_mutual_information(image1, image2):
      hgram, x_edges, y_edges = np.histogram2d(image1.ravel(), image2.ravel(), bins=100)
      pxy = hgram / float(np.sum(hgram))
      px = np.sum(pxy, axis=1)
      py = np.sum(pxy, axis=0)
      px_py = px[:, None] * py[None, :]
      nzs = pxy > 0
      return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))


print("Done")


## **4.2. Start Training**
---
<font size = 4>When you run the cell below, you should see the plot update after each iteration (round). A new network is trained for each image that needs to be registered.
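
<font size = 4>As a reading aid for the curve plotted during training: judging from the `MINE` class and `multi_resolution_loss()` defined above, the value labelled "MI" appears to be the lower bound on mutual information from the MINE paper (https://arxiv.org/abs/1801.04062), averaged over the $L$ pyramid levels. The symbols below are notation introduced here, not taken from the notebook: $T_\theta$ is the small fully connected network, $x_i$ and $y_i$ are the fixed and warped moving intensities at sampled pixel $i$, $\tilde{y}_i$ are the same moving intensities shuffled across the sampled pixels, and $N_s$ is the number of sampled pixels at level $s$.

$$
\widehat{I}_s = \frac{1}{N_s}\sum_{i=1}^{N_s} T_\theta(x_i, y_i) - \log\!\left(\frac{1}{N_s}\sum_{i=1}^{N_s} e^{T_\theta(x_i, \tilde{y}_i)}\right),
\qquad
\mathrm{MI}_{\mathrm{plotted}} = \frac{1}{L}\sum_{s=0}^{L-1}\widehat{I}_s
$$

<font size = 4>The optimizer minimises $-\mathrm{MI}_{\mathrm{plotted}}$, so a rising curve suggests that the alignment is improving. The plotted value is an estimate of a lower bound rather than an exact mutual information.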

<font size = 4>* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of iterations or register fewer images per run. Another way to circumvent this is to save the parameters of the model after training and start training again from this point.
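
<font size = 4>To judge in advance whether a run is likely to fit within the limit, one option is to time a handful of iterations on one image and extrapolate. The sketch below is only an illustration: the helper names `estimate_total_hours` and `fits_time_limit` are hypothetical and not part of this notebook, and the estimate assumes that every iteration and every image takes roughly the same time.

```python
COLAB_LIMIT_HOURS = 12.0  # session limit quoted in the note above


def estimate_total_hours(seconds_per_iteration, iterations_per_image, n_images):
    """Extrapolate the total run time (in hours) from one measured iteration time."""
    total_seconds = seconds_per_iteration * iterations_per_image * n_images
    return total_seconds / 3600.0


def fits_time_limit(seconds_per_iteration, iterations_per_image, n_images,
                    limit_hours=COLAB_LIMIT_HOURS):
    """Return True if the extrapolated run time is below the given limit."""
    estimate = estimate_total_hours(seconds_per_iteration, iterations_per_image, n_images)
    return estimate < limit_hours


# Example with made-up numbers: 0.5 s per iteration, 1000 iterations, 20 images.
hours = estimate_total_hours(0.5, 1000, 20)
print("Estimated run time: %.2f h, fits: %s" % (hours, fits_time_limit(0.5, 1000, 20)))
```

<font size = 4>In this notebook the second argument would correspond to `number_of_iteration` and the third to the number of files in `Moving_image_folder`. Plotting after every iteration adds overhead, so a measured time is more reliable than a guess.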



In [None]:
#@markdown ##Start training and the registration process

start = time.time()

loop_number = 1



if Registration_mode == "Affine":

  print("Affine registration.....")

  for image in os.listdir(Moving_image_folder):

    if path.isfile(Fixed_image_folder):
      I = imread(Fixed_image_folder).astype(np.float32) # fixed image

    if path.isdir(Fixed_image_folder):
      Fixed_image = os.listdir(Fixed_image_folder)
      I = imread(Fixed_image_folder+"/"+Fixed_image[0]).astype(np.float32) # fixed image

    J = imread(Moving_image_folder+"/"+image).astype(np.float32)

  # Here we generate the pyramidal images
    ifplot=True
    if np.ndim(I) == 3:
      nChannel=I.shape[2]
      pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))
      pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))
    elif np.ndim(I) == 2:
      nChannel=1
      pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))
      pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))
    else:
      print("Unknown rank for an image")


  # create a list of necessary objects you will need and commit to GPU
    I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]


    for s in range(L):
        I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)
        J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)

        if nChannel>1:
            I_lst.append(I_.permute(2,0,1))
            J_lst.append(J_.permute(2,0,1))
            h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]

            edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),
                                    np.ones((5,5),np.uint8),
                                    iterations = 1)
            ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]
            ind_lst.append(ind_)
        else:
            I_lst.append(I_)
            J_lst.append(J_)
            h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]

            edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),
                                    np.ones((5,5),np.uint8),
                                    iterations = 1)
            ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]
            ind_lst.append(ind_)

        h_lst.append(h_)
        w_lst.append(w_)

        y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])
        y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0
        xy_ = torch.stack([x_,y_],2)
        xy_lst.append(xy_)

    homography_net = HomographyNet().to(device)
    mine_net = MINE().to(device)

    optimizer = optim.Adam([{'params': mine_net.parameters(), 'lr': 1e-3},
                      {'params': homography_net.vL, 'lr': 5e-3},
                      {'params': homography_net.v1, 'lr': 1e-4}], amsgrad=True)
    mi_list = []
    for itr in range(number_of_iteration):
        optimizer.zero_grad()
        loss = multi_resolution_loss()
        mi_list.append(-loss.item())
        loss.backward()
        optimizer.step()
        clear_output(wait=True)
        plt.plot(mi_list)
        plt.xlabel('Iteration number')
        plt.ylabel('MI')
        plt.title(image+". Image registration "+str(loop_number)+" out of "+str(len(os.listdir(Moving_image_folder)))+".")
        plt.show()

    I_t = torch.tensor(I).to(device) # without Gaussian
    J_t = torch.tensor(J).to(device) # without Gaussian
    H = homography_net(0)
    if nChannel>1:
        J_w = AffineTransform(J_t.permute(2,0,1).unsqueeze(0), H, xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze().permute(1,2,0)
    else:
        J_w = AffineTransform(J_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()

    # Apply the same transform to the additional channels, whatever the channel count of the main image
    if Apply_registration_to_other_channels:

      for n_channel in range(1, int(Number_of_other_channels)+1):

        channel = imread(Additional_channels_folder+"/C"+str(n_channel)+"_"+image).astype(np.float32)
        channel_t = torch.tensor(channel).to(device)
        channel_w = AffineTransform(channel_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()
        channel_registered = channel_w.cpu().data.numpy()
        io.imsave(Result_folder+'/'+model_name+"/"+"C"+str(n_channel)+"_"+image+"_"+Registration_mode+"_registered.tif", channel_registered)
            
# Export results to numpy array
    registered = J_w.cpu().data.numpy()
# Save results
    io.imsave(Result_folder+'/'+model_name+"/"+image+"_"+Registration_mode+"_registered.tif", registered)

    loop_number = loop_number + 1

  print("Your images have been registered and saved in your result_folder")


#Perspective registration

if Registration_mode == "Perspective":

  print("Perspective registration.....")

  for image in os.listdir(Moving_image_folder):

    if path.isfile(Fixed_image_folder):
      I = imread(Fixed_image_folder).astype(np.float32) # fixed image

    if path.isdir(Fixed_image_folder):
      Fixed_image = os.listdir(Fixed_image_folder)
      I = imread(Fixed_image_folder+"/"+Fixed_image[0]).astype(np.float32) # fixed image (first file in the folder)

    J = imread(Moving_image_folder+"/"+image).astype(np.float32)

  # Here we generate the pyramidal images
    ifplot=True
    if np.ndim(I) == 3:
      nChannel=I.shape[2]
      pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))
      pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))
    elif np.ndim(I) == 2:
      nChannel=1
      pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))
      pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))
    else:
      print("Unknown rank for an image")


  # create a list of necessary objects you will need and commit to GPU
    I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]
    for s in range(L):
        I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)
        J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)

        if nChannel>1:
            I_lst.append(I_.permute(2,0,1))
            J_lst.append(J_.permute(2,0,1))
            h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]

            ind_ = torch.randperm(int(h_*w_*sampling))
            ind_lst.append(ind_)
        else:
            I_lst.append(I_)
            J_lst.append(J_)
            h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]

            edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 10),
                                    np.ones((5,5),np.uint8),
                                    iterations = 1)
            ind_ = torch.randperm(int(h_*w_*sampling))
            ind_lst.append(ind_)
        h_lst.append(h_)
        w_lst.append(w_)

        y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])
        y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0
        xy_ = torch.stack([x_,y_],2)
        xy_lst.append(xy_)

    homography_net = HomographyNet().to(device)
    mine_net = MINE().to(device)

    optimizer = optim.Adam([{'params': mine_net.parameters(), 'lr': 1e-3},
                    {'params': homography_net.vL, 'lr': 1e-3},
                    {'params': homography_net.v1, 'lr': 1e-4}], amsgrad=True)
    mi_list = []
    for itr in range(number_of_iteration):
        optimizer.zero_grad()
        loss = multi_resolution_loss()
        mi_list.append(-loss.item())
        loss.backward()
        optimizer.step()
        clear_output(wait=True)
        plt.plot(mi_list)
        plt.xlabel('Iteration number')
        plt.ylabel('MI')
        plt.title(image+". Image registration "+str(loop_number)+" out of "+str(len(os.listdir(Moving_image_folder)))+".")
        plt.show()

    I_t = torch.tensor(I).to(device) # without Gaussian
    J_t = torch.tensor(J).to(device) # without Gaussian
    H = homography_net(0)
    if nChannel>1:
        J_w = PerspectiveTransform(J_t.permute(2,0,1).unsqueeze(0), H, xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze().permute(1,2,0)
    else:
        J_w = PerspectiveTransform(J_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()

    # Apply the same transform to the additional channels, whatever the channel count of the main image
    if Apply_registration_to_other_channels:

      for n_channel in range(1, int(Number_of_other_channels)+1):

        channel = imread(Additional_channels_folder+"/C"+str(n_channel)+"_"+image).astype(np.float32)
        channel_t = torch.tensor(channel).to(device)
        channel_w = PerspectiveTransform(channel_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()
        channel_registered = channel_w.cpu().data.numpy()
        io.imsave(Result_folder+'/'+model_name+"/"+"C"+str(n_channel)+"_"+image+"_Perspective_registered.tif", channel_registered)


# Export results to numpy array
    registered = J_w.cpu().data.numpy()
# Save results
    io.imsave(Result_folder+'/'+model_name+"/"+image+"_Perspective_registered.tif", registered)

    loop_number = loop_number + 1

  print("Your images have been registered and saved in your result_folder")


# PDF export missing 

#pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)



## **4.3. Assess the registration**
---




In [None]:
# @markdown ##Run this cell, then pick a moving image from the dropdown to compare it with its registered output.

# For sliders and dropdown menu and progress bar
from ipywidgets import interact
import ipywidgets as widgets

print('--------------------------------------------------------------')
@interact
def show_QC_results(file = os.listdir(Moving_image_folder)):

  moving_image = imread(Moving_image_folder+"/"+file).astype(np.float32)
  
  registered_image = imread(Result_folder+"/"+model_name+"/"+file+"_"+Registration_mode+"_registered.tif").astype(np.float32)

#Here we display one image

  f=plt.figure(figsize=(20,20))
  plt.subplot(1,5,1)
  plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest')
  plt.title('Fixed image')
  plt.axis('off');

  plt.subplot(1,5,2)
  plt.imshow(moving_image, norm=simple_norm(moving_image, percent = 99), interpolation='nearest')
  plt.title('Moving image')
  plt.axis('off');

  plt.subplot(1,5,3)
  plt.imshow(registered_image, norm=simple_norm(registered_image, percent = 99), interpolation='nearest')
  plt.title("Registered image")
  plt.axis('off');

  plt.subplot(1,5,4)
  plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest', cmap="Greens")
  plt.imshow(moving_image, norm=simple_norm(moving_image, percent = 99), interpolation='nearest', cmap="Oranges", alpha=0.5)
  plt.title("Fixed and moving images")
  plt.axis('off');

  plt.subplot(1,5,5)
  plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest', cmap="Greens")
  plt.imshow(registered_image, norm=simple_norm(registered_image, percent = 99), interpolation='nearest', cmap="Oranges", alpha=0.5)
  plt.title("Fixed and Registered images")
  plt.axis('off');

  plt.show()

## **4.4. Download your predictions**
---

<font size = 4>**Store your data** and ALL its results elsewhere by downloading them from Google Drive. Afterwards, clean the original folder tree (datasets, results, etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name.
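
<font size = 4>One way to make the download easier is to bundle the result folder into a single zip file first. The sketch below is a minimal illustration using only the Python standard library. The helper name `archive_results` and the `_backup` suffix are hypothetical, and the commented usage line assumes the `Result_folder` and `model_name` variables set earlier in this notebook.

```python
import os
import shutil


def archive_results(result_dir, archive_base):
    """Zip the contents of result_dir into archive_base + '.zip' and return the zip path."""
    if not os.path.isdir(result_dir):
        raise FileNotFoundError("Result folder not found: " + result_dir)
    # The archive is written next to (not inside) result_dir, so it does not include itself.
    return shutil.make_archive(archive_base, "zip", root_dir=result_dir)


# Possible usage in this notebook (assumes the variables defined in earlier cells):
# zip_path = archive_results(Result_folder + "/" + model_name,
#                            Result_folder + "/" + model_name + "_backup")
# print("Archive written to", zip_path)
```

<font size = 4>Because the notebook removes an existing folder with the same `model_name` before registering, making such an archive before re-running protects earlier results.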

#**Thank you for using DRMIME 2D!**