#**Label-free Prediction (fnet)**
---

<font size = 4> 
Label-free Prediction (Fnet) is a neural network that predicts the features of cellular structures from brightfield or EM images, without requiring fluorescent labels at prediction time. The network is trained on paired images of the same field of view, acquired once in a label-free condition (e.g. brightfield) and once in a labelled condition (e.g. fluorescent protein). Once trained, it allows the user to identify certain structures from brightfield images alone. How well fnet performs may depend significantly on the structure being predicted.

---
<font size = 4> *Disclaimer*:

<font size = 4> This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki), jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.

<font size = 4> This notebook is largely based on the paper: **Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by *Chawin Ounkomol, Sharmishtaa Seshamani, Mary M. Maleckar, Forrest Collman & Gregory R. Johnson*  (https://www.nature.com/articles/s41592-018-0111-2)

<font size = 4> And source code found in: https://github.com/AllenCellModeling/pytorch_fnet

<font size = 4> **Please also cite this original paper when using or developing this notebook.** 

# **How to use this notebook?**
---

<font size = 4>Videos describing how to use our notebooks are available on YouTube:
  - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook
  - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook


---
###**Structure of a notebook**

<font size = 4>The notebook contains two types of cell:  

<font size = 4>**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.

<font size = 4>**Code cells** contain code, which can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears) and click it. When execution is complete, the play-button animation stops. You can create a new code cell by clicking `+ Code`.

---
###**Table of contents, Code snippets** and **Files**

<font size = 4>On the top left side of the notebook you find three tabs which contain from top to bottom:

<font size = 4>*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.

<font size = 4>*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.

<font size = 4>*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. 

<font size = 4>**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.

<font size = 4>**Note:** The "sample data" in "Files" contains default files. Do not upload anything in here!

---
###**Making changes to the notebook**

<font size = 4>**You can make a copy** of the notebook and save it to your Google Drive. To do this, click File -> Save a copy in Drive.

<font size = 4>To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).
You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment.

#**0. Before getting started**
---

<font size = 4> This notebook provides two options: first, to download and train Fnet with data published in the original manuscript, or second, to upload a personal dataset and train Fnet on it.
<font size = 4> The notebook may require a large amount of disk space. If using the datasets from the paper, the user's Google Drive should have at least 40 GB of free space.
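<font size = 4>A quick way to check free space before downloading is Python's standard `shutil.disk_usage`. The sketch below is a hypothetical helper (`has_enough_space` is not part of fnet or this notebook) that uses the 40 GB figure above as its default threshold; the free space reported for a mounted Google Drive folder may not exactly match your Drive quota.

```python
import shutil

def has_enough_space(path, required_gb=40.0):
    """Return True if the filesystem holding `path` has at least `required_gb` GB free.

    Hypothetical helper; GB here means 10**9 bytes.
    """
    free_bytes = shutil.disk_usage(path).free
    return free_bytes / 1e9 >= required_gb

# Hypothetical usage once Google Drive is mounted (section 1.2):
# print(has_enough_space("/content/gdrive/My Drive"))
```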

---
<font size = 4>**Data Format**

<font size = 4> **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on your own data, upload them to your Google Drive in the following format. To ensure that corresponding images are paired during training, give each signal image and its target image the same file name.

<font size = 4>Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki

<font size = 4> **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** 


*   Experiment A
    - **Training dataset**
      - bright-field images
        - img_1.tif, img_2.tif, ...
      - fluorescence images
        - img_1.tif, img_2.tif, ...
    - **Quality control dataset**
      - bright-field images
        - img_1.tif, img_2.tif
      - fluorescence images
        - img_1.tif, img_2.tif
    - **Data to be predicted**
    - **Results**
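<font size = 4>Because signal and target stacks are paired by file name, it can help to check both folders before training. The following is a minimal sketch with hypothetical helper names (`match_pairs`, `folder_name_ok`); it assumes `.tif`/`.tiff` files and treats round, square and curly brackets as the forbidden "brackets" mentioned above.

```python
import os

# Assumption: "brackets" means round, square and curly brackets.
FORBIDDEN_CHARS = " ()[]{}"

def folder_name_ok(path):
    """True if the last folder name in `path` contains no spaces or brackets."""
    name = os.path.basename(os.path.normpath(path))
    return not any(c in FORBIDDEN_CHARS for c in name)

def match_pairs(source_dir, target_dir, extensions=(".tif", ".tiff")):
    """Pair signal and target stacks that share a file name.

    Returns a name-sorted list of (signal_path, target_path) tuples and
    raises ValueError if a stack exists in only one of the two folders.
    """
    def stack_names(folder):
        return {f for f in os.listdir(folder) if f.lower().endswith(extensions)}

    source_names = stack_names(source_dir)
    target_names = stack_names(target_dir)
    unpaired = source_names ^ target_names  # names present in only one folder
    if unpaired:
        raise ValueError("Unpaired stacks: " + ", ".join(sorted(unpaired)))
    return [(os.path.join(source_dir, n), os.path.join(target_dir, n))
            for n in sorted(source_names)]
```

For example, `match_pairs(Training_source, Training_target)` would raise an error listing any stack that has no partner in the other folder.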

<font size = 4>**Important note**

<font size = 4>- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.

<font size = 4>- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.

<font size = 4>- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.

---


# **1. Initialise the Colab session**
---

## **1.1. Check for GPU access**
---

<font size = 4>By default, the session should use Python 3 and GPU acceleration, but you can make sure these are set correctly as follows:

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(Graphics processing unit)*


In [0]:
#@markdown ##Run this cell to check if you have GPU access
%tensorflow_version 1.x

import tensorflow as tf
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime?') 
  print('If the runtime settings are correct then Google did not allocate a GPU to your session')
  print('Expect slow performance. To access a GPU try reconnecting later')

else:
  print('You have GPU access')

from tensorflow.python.client import device_lib 
device_lib.list_local_devices()

## **1.2. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of notebook.

In [0]:
#@markdown ##Play the cell to connect your Google Drive to Colab

#@markdown * Click on the URL. 

#@markdown * Sign in to your Google Account. 

#@markdown * Copy the authorization code. 

#@markdown * Enter the authorization code. 

#@markdown * Click on the "Files" tab on the left and refresh it. Your Google Drive folder should now be available here as "drive". 

# mount user's Google Drive to Google Colab.
from google.colab import drive
drive.mount('/content/gdrive')

#**2. Install Fnet and dependencies**
---
<font size = 4>Running fnet requires the pytorch_fnet code, which this cell clones from GitHub and moves to your Google Drive (`My Drive/pytorch_fnet`). As fnet needs several packages to be installed, this step may take a few minutes.

<font size = 4>You can ignore **the error warnings** as they refer to packages not required for this notebook.

<font size = 4>**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**

In [0]:
#@markdown ##Play this cell to download fnet to your drive. If it is already installed this will only install the fnet dependencies.

import os
import csv
import shutil

#Ensure tensorflow 1.x
%tensorflow_version 1.x
import tensorflow
print(tensorflow.__version__)

print("Tensorflow enabled.")

#Install the scipy version used by this notebook
!pip install -U scipy==1.2.0

#Clone fnet from GitHub to Colab, install it, then move it to your Google Drive
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet'):
  !git clone -b release_1 --single-branch https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .
  shutil.move('/content/pytorch_fnet','/content/gdrive/My Drive/pytorch_fnet')

#**3. Select your paths and parameters**
---
<font size = 5> **Paths for training data**

<font size = 4> **`Training_source`,`Training_target:`** These are the paths to the folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data, respectively. To find the path of each folder, go to **Files** on the left of the notebook, navigate to the folder containing your files, right-click it, select **Copy path**, and paste the path into the corresponding box below.

<font size = 4>**Note: The stacks for fnet should either have 32 or more slices or have a number of slices that is a power of 2 (e.g. 2, 4, 8, 16).**
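<font size = 4>The slice-count rule can be written as a small check. This sketch (`valid_slice_count` is a hypothetical name) assumes that a single slice does not count, since fnet expects 3D stacks; reading the slice count with `tifffile.imread(path).shape[0]` further assumes that Z is the first axis of the stack.

```python
def valid_slice_count(n_slices):
    """Return True if a stack with n_slices slices satisfies the rule above."""
    # n & (n - 1) == 0 holds for 0 and for powers of two; 0 and 1 are excluded here
    is_power_of_two = n_slices >= 2 and (n_slices & (n_slices - 1)) == 0
    return n_slices >= 32 or is_power_of_two

# Hypothetical usage on one stack (assumes Z is the first axis):
# import tifffile
# print(valid_slice_count(tifffile.imread("img_1.tif").shape[0]))
```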

<font size = 5> **Training Parameters**

<font size = 4> **`steps:`** The number of iterations to train the network for. A larger number may improve performance but risks overfitting to the training data. Good performance with fnet typically requires tens of thousands of iterations, which usually takes **several hours**, depending on the dataset size. **Default: 50000**
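<font size = 4>To get a rough idea of how long a given number of `steps` will take, you can time a short run and extrapolate linearly. This is a back-of-the-envelope sketch with a hypothetical helper name; the example numbers come from a comment in the training cell of section 4.1, which reports about 1:30 h for 5000 training steps with 50 images in the buffer. Your own pace will depend on the GPU you are allocated and on your data.

```python
def estimate_training_hours(target_steps, measured_steps, measured_seconds):
    """Linearly extrapolate the total training time (in hours) from a timed run."""
    total_seconds = target_steps * measured_seconds / measured_steps
    return total_seconds / 3600

# ~1.5 h (5400 s) for 5000 steps, extrapolated to the default 50000 steps
print(estimate_training_hours(50000, 5000, 5400))  # → 15.0
```

At that pace the default 50000 steps would take about 15 hours, which may be longer than a single Colab session lasts; section 4.2 describes how to continue training from a saved model.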


<font size = 5>**Advanced Parameters - experienced users only**


<font size =4>**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down training, respectively, and can influence network performance. **Default: 4**

<font size = 5>**Troubleshooting:**

<font size = 4> **Not enough disk space to download fnet data**: If Colab disconnects during the download, the disk space provided by Colab may have run out. Usually, this is prevented because the data are moved to the user's Google Drive. If the problem persists even after restarting the runtime, the dataset may need to be downloaded externally and then moved manually to Google Drive (use the code below, but on your own machine).

In [0]:
#@markdown ###Datasets

import os
import csv
import shutil
from tempfile import mkstemp
from shutil import move, copymode
from os import fdopen, remove

#In the first step we need to change some parameters in the default files

#This function replaces the old default files with new values
def replace(file_path, pattern, subst):
    #Create temp file
    fh, abs_path = mkstemp()
    with fdopen(fh,'w') as new_file:
        with open(file_path) as old_file:
            for line in old_file:
                new_file.write(line.replace(pattern, subst))
    #Copy the file permissions from the old file to the new file
    copymode(file_path, abs_path)
    #Remove original file
    remove(file_path)
    #Move new file
    move(abs_path, file_path)

#Here we replace values in the old files
#Change maximum pixel number
replace("/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py",'n_max_pixels=9732096','n_max_pixels=20000000')
replace("/content/gdrive/My Drive/pytorch_fnet/predict.py",'6000000','20000000')

#Prevent resizing in the training and the prediction
replace("/content/gdrive/My Drive/pytorch_fnet/predict.py","0.37241","1.0")
replace("/content/gdrive/My Drive/pytorch_fnet/train_model.py","0.37241","1.0")

#Datasets

#Fetch the path and extract the name of the signal folder
Training_source = "" #@param {type: "string"}
source_name = os.path.basename(os.path.normpath(Training_source))

#Fetch the path and extract the name of the target folder
Training_target = "" #@param {type: "string"}
target_name =  os.path.basename(os.path.normpath(Training_target))

#@markdown ###Model name and model path
model_name = "" #@param {type:"string"}
model_path = "" #@param {type:"string"}

if os.path.exists(model_path+'/'+model_name):
  shutil.rmtree(model_path+'/'+model_name)
  
dataset = model_name #The name of the dataset and the model will be the same

#Here, we check if the dataset already exists. If not, copy the dataset from google drive to the data folder
  
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset):
  #shutil.copytree(own_dataset,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)
  os.makedirs('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)
  shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
  shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)
elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset) and not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name):
  shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
  shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)
  
#Create a path_csv file to point to the training images.
#os.listdir returns files in arbitrary order, so both listings are sorted
#to put signal and target files with the same name on the same row.
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')
source = sorted(os.listdir('./'+dataset+'/'+source_name))
target = sorted(os.listdir('./'+dataset+'/'+target_name))
with open(dataset+'.csv', 'w', newline='') as file:
  writer = csv.writer(file)
  writer.writerow(["path_signal","path_target"])
  for i in range(0,len(source)):
    writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+source_name+"/"+source[i],"/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+target_name+"/"+target[i]])

shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'.csv')

#print("Selected "+dataset+" as training set")

dataset_x = dataset+"}" # this variable is only used to ensure closed curly brackets when editing the .sh files

#@markdown ---

#@markdown ###Training Parameters

#Training parameters in fnet are indicated in the train_model.sh file.
#Here, we edit this file to include the desired parameters

#1. Add permissions to train_model.sh
os.chdir("/content/gdrive/My Drive/pytorch_fnet/scripts")
!chmod u+x train_model.sh

#2. Select parameters
steps =  50000#@param {type:"number"}
batch_size =  4#@param {type:"number"}
number_of_images =  len(os.listdir(Training_source))

#3. Insert the above values into train_model.sh
!if ! grep saved_models\/\${ train_model.sh;then sed -i 's/saved_models\/.*/saved_models\/\${DATASET}"/g' train_model.sh; fi 
!sed -i "s/1:-.*/1:-$dataset_x/g" train_model.sh #change the dataset to be trained with
!sed -i "s/N_ITER=.*/N_ITER=$steps/g" train_model.sh #change the number of training iterations (steps)
!sed -i "s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g" train_model.sh #change the number of training images
!sed -i "s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g" train_model.sh #change the batch size

#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.
!sed -i "s/train_size .* -v/train_size 1.0 -v/g" train_model.sh

#4. class_dataset
#Fnet uses specific dataset classes depending on the type of input data the user uses.
#Here, we use the path_csv file to find out if the dataset needs to be a tiff or czi dataset.

mycsv = csv.reader(open('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'.csv'))
for row in mycsv:
   text = row[0]
os.chdir("/content/gdrive/My Drive/pytorch_fnet/")

#If the dataset has tif or tiff files, and the train_model.sh has either no class_dataset argument 
#or class_dataset = CziDataset we change it to TiffDataset
if text.endswith('.tif') or text.endswith('.tiff'):
  !chmod u+x ./scripts/train_model.sh
  !if ! grep class_dataset ./scripts/train_model.sh;then sed -i 's/ITER} \\/ITER} \\\'$''\n'       --class_dataset TiffDataset \\/' ./scripts/train_model.sh; fi
  !if grep CziDataset ./scripts/train_model.sh;then sed -i 's/CziDataset/TiffDataset/' ./scripts/train_model.sh; fi

#If new parameters are inserted here for training a model with the same name
#the previous training csv needs to be removed, to prevent the model using the old training split or paths.
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset)

#**4. Train the model**
---
<font size = 4>Before training, carefully read the different options. This applies especially if you have trained fnet on a dataset before.


###**Choose one of the options to train fnet**.

<font size = 4>**4.1.** If this is the first training on the chosen dataset, play this section to start training.

<font size = 4>**4.2.** If you want to continue training a previously trained model, use this section.

<font size = 4><font color = red> **Carefully read the options before starting training.**

##**4.1. Train a new model**
---

####Play the cell below to start training. 

<font size = 4>**Note:** If you train a model with the same name as a previous one, the previous model will be overwritten. To keep it, save it before playing the cell below or give your new model a different name (section 3).
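<font size = 4>One way to keep the previous model is to copy its folder under a timestamped name first. A minimal sketch with a hypothetical helper is shown below. Be aware that the section 3 cell already deletes an existing `model_path/model_name` folder, so at this point the remaining copy is the one in `pytorch_fnet/saved_models`.

```python
import os
import shutil
import time

def backup_model(model_dir):
    """Copy model_dir to '<model_dir>_backup_<timestamp>' and return the new path.

    Hypothetical helper; returns None when model_dir does not exist.
    """
    if not os.path.isdir(model_dir):
        return None
    stamp = time.strftime("%Y%m%d_%H%M%S")
    backup_dir = os.path.normpath(model_dir) + "_backup_" + stamp
    shutil.copytree(model_dir, backup_dir)
    return backup_dir

# Hypothetical usage before playing the training cell:
# backup_model('/content/gdrive/My Drive/pytorch_fnet/saved_models/' + model_name)
```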

In [0]:
import datetime
import time
start = time.time()

#Delete any previous model with the same name from pytorch_fnet/saved_models
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset)

#@markdown ##4.1. Start Training
#with 50 images in buffer from dataset and 5000 training steps, takes around 1:30h

#This tifffile release runs error-free in this version of fnet.
!pip install tifffile==2019.7.26

os.chdir('/content/gdrive/My Drive/pytorch_fnet/fnet/')

#Here we add an import of fnet.fnet_model to the functions.py file to run it without errors.
#This may be a small bug in the original code.
with open("functions.py", "r") as f:
  contents = f.readlines()
if 'import fnet.fnet_model' not in "".join(contents):
  #The trailing newline keeps the inserted import on its own line
  contents.insert(5, 'import fnet.fnet_model\n')
with open("functions.py", "w") as f:
  f.write("".join(contents))
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')

print("Let's start the training!")
#Here we start the training
!./scripts/train_model.sh $dataset 0

#After training overwrite any existing model in the model_path with the new trained model.
if os.path.exists(model_path+'/'+dataset):
  shutil.rmtree(model_path+'/'+dataset)
shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset,model_path+'/'+dataset)

#Get rid of duplicates of training data in pytorch_fnet after training completes
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)

# Display the time elapsed for training
dt = time.time() - start
minutes, sec = divmod(dt, 60)
hours, minutes = divmod(minutes, 60)
print("Time elapsed:", int(hours), "hour(s)", int(minutes), "min(s)", round(sec), "sec(s)")


<font size = 4>**Note:** Fnet training takes a long time. If your notebook times out because of the length of the training or a loss of GPU acceleration, the last checkpoint is saved in the saved_models folder inside the pytorch_fnet folder. To save it in a more convenient location on your drive, remount the drive (if you were disconnected) and, in the next cell, enter the name of the model that timed out (`model_name`) and the location where you want to save it (`model_path`) before continuing with 4.2. **If training did not time out, you can ignore this cell.**

In [0]:
#@markdown ##Play this cell if your model training timed out and indicate where you want to save the last checkpoint.

import shutil
import os
model_name = "" #@param {type:"string"}
model_path = "" #@param {type:"string"}

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):
  shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)
else:
  print('This model name does not exist in your saved_models folder. Make sure you have entered the name of the model that timed out.')

##**4.2. Training from a previously saved model**
---
<font size = 4>This section allows you to continue training networks you have previously trained and saved for more training steps. The folders have the same meaning as in section 3. If you want to keep the previously trained model, save it separately now, as this section will overwrite the weights of the old model.

<font size = 4>**Note: To use this section, the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Play section 2 to make sure.)**
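<font size = 4>The next cell works out how many steps were already trained by counting the lines of `losses.csv`. An alternative, sketched below with a hypothetical helper, reads the iteration number from the last row instead. It assumes, as the loss plot in section 5.1 does, that the first column holds the iteration number, and it gives the same answer as line counting only if every iteration from 1 onwards was logged exactly once.

```python
import csv

def last_iteration(losses_csv):
    """Return the iteration number in the first column of the last data row.

    Hypothetical helper; returns None if the file holds only a header (or nothing).
    """
    last = None
    with open(losses_csv, newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            if row:  # ignore empty trailing lines
                last = int(row[0])
    return last

# Hypothetical usage:
# new_steps = last_iteration(model_path + '/' + dataset + '/losses.csv') + add_steps
```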

In [0]:
#@markdown To test if performance improves after the initial training, you can continue training on the old model. This option can also be useful if Colab disconnects or times out.
#@markdown Enter the paths of the datasets you want to continue training on.
import csv

Training_source = "" #@param {type: "string"}
source_name = os.path.basename(os.path.normpath(Training_source))

#Fetch the path and extract the name of the target folder
Training_target = "" #@param {type: "string"}
target_name =  os.path.basename(os.path.normpath(Training_target))

model_name = "" #@param {type:"string"}
dataset = model_name

model_path = "" #@param {type:"string"}

batch_size = 4 #@param {type:"number"}

dataset_x = dataset+"}"

#Move your model to fnet
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset):
  shutil.copytree(model_path+'/'+model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset)

#Move the datasets into fnet
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset):
  os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)
  shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
  shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)
elif not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name):
  shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
  shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)

os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')

number_of_images =  len(os.listdir(Training_source))
#Change the train_model.sh file to include chosen dataset
!chmod u+x train_model.sh
!sed -i "s/1:-.*/1:-$dataset_x/g" train_model.sh
!sed -i "s/train_size .* -v/train_size 1.0 -v/g" train_model.sh #Use the whole training dataset for training
!sed -i "s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g" train_model.sh #change the number of training images
!sed -i "s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g" train_model.sh #change the batch size

#Make sure a training csv file exists and if not make it from the training folders.
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'.csv'):
  os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')
  #Sort both listings so that signal and target files with the same name are paired
  source = sorted(os.listdir('./'+dataset+'/'+source_name))
  target = sorted(os.listdir('./'+dataset+'/'+target_name))
  with open(dataset+'.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["path_signal","path_target"])
    for i in range(0,len(source)):
      writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+source_name+"/"+source[i],"/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+target_name+"/"+target[i]])

  shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'.csv')

#Make sure you're using .tif files as above in section 3.
mycsv = csv.reader(open('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'.csv'))
for row in mycsv:
   text = row[0]
os.chdir("/content/gdrive/My Drive/pytorch_fnet/")

#If the dataset has tif or tiff files, and the train_model.sh has either no class_dataset argument 
#or class_dataset = CziDataset we change it to TiffDataset
if text.endswith('.tif') or text.endswith('.tiff'):
  !chmod u+x ./scripts/train_model.sh
  !if ! grep class_dataset ./scripts/train_model.sh;then sed -i 's/ITER} \\/ITER} \\\'$''\n'       --class_dataset TiffDataset \\/' ./scripts/train_model.sh; fi
  !if grep CziDataset ./scripts/train_model.sh;then sed -i 's/CziDataset/TiffDataset/' ./scripts/train_model.sh; fi


#Find the number of previous training iterations (steps) from loss csv file

with open(model_path+'/'+dataset+'/losses.csv') as f:
  previous_steps = sum(1 for line in f)
print('continuing training after step '+str(previous_steps-1))

print('To start re-training play section 4.2. below')

#@markdown For how many additional steps do you want to train the model?
add_steps =  5000#@param {type:"number"}

#Calculate the new total number of training steps. Subtract 1 to discount the header row of the csv file.
new_steps = previous_steps + add_steps -1
os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')

#Edit the train_model.sh file to include the new total number of training steps
!sed -i "s/N_ITER=.*/N_ITER=$new_steps/g" train_model.sh

In [0]:
import datetime
import time
start = time.time()

#@markdown ##4.2. Start re-training model
!pip install tifffile==2019.7.26
import os
os.chdir('/content/gdrive/My Drive/pytorch_fnet/fnet')

#Here we add an import of fnet.fnet_model to the functions.py file to run it without errors.
with open("functions.py", "r") as f:
  contents = f.readlines()
if 'import fnet.fnet_model' not in "".join(contents):
  #The trailing newline keeps the inserted import on its own line
  contents.insert(5, 'import fnet.fnet_model\n')
with open("functions.py", "w") as f:
  f.write("".join(contents))

#Here we retrain the model on the chosen dataset.
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')
!chmod u+x ./scripts/train_model.sh
!./scripts/train_model.sh $dataset 0

#Update the existing model_path folder with the new training results.
if os.path.exists(model_path+'/'+dataset):
  shutil.rmtree(model_path+'/'+dataset)
shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset,model_path+'/'+dataset)

#Get rid of duplicates of training data in pytorch_fnet after training completes
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)

# Display the time elapsed for training
dt = time.time() - start
minutes, sec = divmod(dt, 60)
hours, minutes = divmod(minutes, 60)
print("Time elapsed:", int(hours), "hour(s)", int(minutes), "min(s)", round(sec), "sec(s)")

## **4.3. Download your model(s) from Google Drive**
---

<font size = 4>Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder selected in section 3. However, it is wise to download the folder, as its contents can be erased by the next training run if the same folder is used.
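<font size = 4>To download the model folder in one go, you can first zip it with Python's standard `shutil.make_archive`. This is a sketch with a hypothetical helper name; in Colab the resulting file could then be downloaded with `google.colab.files.download` or from the **Files** tab.

```python
import os
import shutil

def archive_model(model_dir, out_dir):
    """Zip model_dir into out_dir/<folder name>.zip and return the archive path."""
    model_dir = os.path.normpath(os.path.abspath(model_dir))
    parent, name = os.path.split(model_dir)
    # root_dir/base_dir keep the model folder itself as the top level inside the zip
    return shutil.make_archive(os.path.join(out_dir, name), "zip",
                               root_dir=parent, base_dir=name)

# Hypothetical usage in Colab:
# from google.colab import files
# files.download(archive_model(model_path + '/' + model_name, '/content'))
```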

# **5. Evaluate your model**
---

This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 

**We highly recommend performing quality control on all newly trained models.**

In [0]:
import os
import shutil
# model name and path
#@markdown ###Do you want to assess the model you just trained?

Use_the_current_trained_model = True #@param {type:"boolean"}

#@markdown ###If not, please provide the name of the model and path to model folder:
#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter 'model_name' (see section 3). Provide the name of this folder as 'QC_model_name' and the path to its parent folder in 'QC_model_path'. 

QC_model_name = "" #@param {type:"string"}
QC_model_path = "" #@param {type:"string"}

if (Use_the_current_trained_model): 
  print("Using current trained network")
  QC_model_name = model_name
  QC_model_path = model_path

#Create a folder for the quality control metrics
if os.path.exists(QC_model_path+"/"+QC_model_name+"/Quality Control"):
  shutil.rmtree(QC_model_path+"/"+QC_model_name+"/Quality Control")
os.makedirs(QC_model_path+"/"+QC_model_name+"/Quality Control")

print("The "+QC_model_name+" network will be evaluated")

## **5.1. Inspection of the loss function**
---

<font size = 4>It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*

<font size = 4>**Loss** <code>(loss)</code> is the error between the model's prediction and its ground-truth ('GT') target, recorded at each training iteration.

**Note:** Validation Loss is currently not implemented in this notebook.



In [0]:
#@markdown ##Play the cell to show figure of training errors

x = []
y = []
import csv
from matplotlib import pyplot as plt
with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:
    plots = csv.reader(csvfile, delimiter=',')
    next(plots)
    for row in plots:
        x.append(int(row[0]))
        y.append(float(row[1]))

plt.figure(figsize=(16,5))

plt.plot(x,y)
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Iteration')

#Save before showing: after plt.show() the notebook may leave an empty figure to save
plt.savefig(QC_model_path+'/'+QC_model_name+'/'+'losses.png')
plt.show()

## **5.2. Error mapping and quality metrics estimation**
---

<font size = 4>This section displays Square Error maps and SSIM maps and calculates NRMSE and SSIM metrics for all the images provided in the "Source_QC_folder" and "Target_QC_folder".

<font size = 4>**The Square Error map** displays the square of the difference between the normalized prediction and target, or between the source and target. A smaller SE is better: perfect agreement between target and prediction gives an image that is zero everywhere.
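<font size = 4>As an illustration of the Square Error map, the sketch below computes it with NumPy for a single 2D image (for example one slice of a stack). The min-max normalization is an assumption made for this example; the normalization used by the notebook's own quality-control code is not shown in this section and may differ, and the RMSE here is not necessarily the exact NRMSE definition the notebook reports.

```python
import numpy as np

def minmax_normalize(img):
    """Rescale an image to [0, 1]; a constant image maps to all zeros."""
    img = np.asarray(img, dtype=np.float64)
    lo, hi = img.min(), img.max()
    if hi == lo:
        return np.zeros_like(img)
    return (img - lo) / (hi - lo)

def square_error_map(prediction, target):
    """Pixel-wise squared difference between the normalized images."""
    return (minmax_normalize(prediction) - minmax_normalize(target)) ** 2

def rmse(prediction, target):
    """Root of the mean of the square error map; 0 means perfect agreement."""
    return float(np.sqrt(square_error_map(prediction, target).mean()))

target = np.arange(16.0).reshape(4, 4)
print(rmse(target, target))  # → 0.0
```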

<font size = 4>**The SSIM (structural similarity)** is a common metric that assesses whether two images contain the same structures. It is a normalized metric, and an SSIM of 1 indicates perfect similarity between two images; therefore, for SSIM, the closer to 1, the better. The SSIM map computes the SSIM metric at each pixel by considering the structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with Gaussian weighting of 1.5 pixels standard deviation; see our Wiki for more info).
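<font size = 4>The sketch below implements the standard SSIM map formula for a 2D image with NumPy, using the window described above (11 pixels, Gaussian weights with a standard deviation of 1.5 pixels) and the commonly used constants K1 = 0.01 and K2 = 0.03. It is an illustrative re-implementation, not the notebook's own code; for real analyses a tested library function such as scikit-image's `structural_similarity` is preferable.

```python
import numpy as np

def gaussian_kernel(sigma=1.5, radius=5):
    """1D Gaussian weights over 2 * radius + 1 = 11 pixels, summing to 1."""
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    k = np.exp(-x ** 2 / (2 * sigma ** 2))
    return k / k.sum()

def gaussian_filter(img, kernel):
    """Separable Gaussian smoothing with reflected borders; keeps img's shape."""
    r = len(kernel) // 2
    padded = np.pad(img, r, mode="reflect")
    smooth = lambda v: np.convolve(v, kernel, mode="valid")
    vertical = np.apply_along_axis(smooth, 0, padded)  # filter along columns
    return np.apply_along_axis(smooth, 1, vertical)    # then along rows

def ssim_map(x, y, data_range=1.0, k1=0.01, k2=0.03):
    """Per-pixel SSIM between two 2D images whose values span data_range."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    kernel = gaussian_kernel()
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    # Local (Gaussian-weighted) means, variances and covariance
    mu_x = gaussian_filter(x, kernel)
    mu_y = gaussian_filter(y, kernel)
    var_x = gaussian_filter(x * x, kernel) - mu_x ** 2
    var_y = gaussian_filter(y * y, kernel) - mu_y ** 2
    cov_xy = gaussian_filter(x * y, kernel) - mu_x * mu_y
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator
```

The mean of the map, e.g. `ssim_map(prediction, target).mean()` for two slices normalized to [0, 1], gives a single SSIM value per image.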




In [0]:
import shutil  # no need to import these, they're already imported at install
import csv
import os

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')

!pip install -U scipy==1.2.0
!pip install --no-cache-dir tifffile==2019.7.26 
from distutils.dir_util import copy_tree


#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#


#Choose the folder with the quality control datasets
Source_QC_folder = "" #@param{type:"string"}
Target_QC_folder = "" #@param{type:"string"}

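Predictions_name = "QualityControl" 
# predict.sh takes the dataset name as ${1:-<default>}; the closing brace is appended so that the sed command below can rewrite that default
Predictions_name_x = Predictions_name+"}"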

#If the folder you are creating already exists, delete the existing version to overwrite.
if os.path.exists(QC_model_path+"/"+QC_model_name+"/Quality Control/"+Predictions_name):
  shutil.rmtree(QC_model_path+"/"+QC_model_name+"/Quality Control/"+Predictions_name)

if Use_the_current_trained_model == True:
  #Copy the contents of your trained model folder into a new folder in saved_models
  #copy_tree is used here because only the contents of the model folder are needed
  copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)
else:
  copy_tree(QC_model_path+'/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)
  dataset = QC_model_name

# Get the name of the folder the test data is in
source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))
target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))

# Make predict.sh executable and set its default dataset name to Predictions_name.
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')
!chmod u+x /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh
!sed -i "s/1:-.*/1:-$Predictions_name_x/g" /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh

#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.
!sed -i "s/in test.*/in test/g" /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh

#If the data are .tif files, make predict.sh use fnet's TiffDataset class instead of CziDataset
file_list = os.listdir(Source_QC_folder)
text = file_list[0]

if text.endswith('.tif') or text.endswith('.tiff'):
  !chmod u+x /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh
  !if ! grep class_dataset /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\/DIR} \\\'$''\n'     --class_dataset TiffDataset \\/' /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh; fi
  !if grep CziDataset /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh; fi   

#Create test_data folder in pytorch_fnet

# If your test data is not in the pytorch_fnet data folder it needs to be copied there.
if Use_the_current_trained_model == True:
  if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name):
    shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name)
    shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_dataset_name)
else:
  if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name):
    shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name)
    shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+target_dataset_name)


# Make a folder that will hold the test.csv file in your new folder
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):
  os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)


os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')

#Make a new folder in saved_models to use the trained model for inference.
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):
  os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) 


#Get file list from the folders containing the files you want to use for inference.
#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name)
test_signal = os.listdir(Source_QC_folder)
test_target = os.listdir(Target_QC_folder)
#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')

#If an old test csv exists we want to overwrite it, so we can insert new test data.
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):
  os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')

#Here we create a new test.csv
with open('test.csv', 'w', newline='') as file:
      writer = csv.writer(file)
      writer.writerow(["path_signal","path_target"])
      for i in range(0,len(test_signal)):
        if Use_the_current_trained_model == True:
          writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+source_dataset_name+"/"+test_signal[i],"/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+target_dataset_name+"/"+test_signal[i]])
          # This currently assumes that the names are identical for source and target: see "test_target" variable is never used
        else:
          writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+Predictions_name+"/"+source_dataset_name+"/"+test_signal[i],"/content/gdrive/My Drive/pytorch_fnet/data/"+Predictions_name+"/"+target_dataset_name+"/"+test_signal[i]])

#We run the predictions
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')
!/content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0

#Save the results
QC_results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test')

if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction'):
  shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')
os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')

if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal'):
  shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')
os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')

if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Target'):
  shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')
os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')

for i in range(len(QC_results_files)-2):
  #pred_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_folder+'/test/'+results_files[i])
  # shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/prediction_'+Predictions_name+'.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction/'+str(i)+'.tiff')
  # shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Signal/'+str(i)+'.tiff')
  # shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Target/'+str(i)+'.tiff')
  shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/prediction_'+Predictions_name+'.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction/'+'Predicted_'+test_signal[i])
  shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Signal/'+test_signal[i])
  shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Target/'+test_target[i])

shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)


#-----------------------------METRICS EVALUATION-------------------------------#

##@markdown ##Give the paths to an image to test the performance of the model with.
!pip install matplotlib==2.2.3
import sys
import numpy as np
from scipy import signal
from scipy import ndimage
from skimage import io
from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression
from skimage.util import img_as_uint
import matplotlib as mpl

def gauss(size, sigma):

    """This function is used to create a Gaussian window for the calculation of SSIM, according to Wang et al.
    """
    x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
    g = np.exp(-((x**2 + y**2)/(2.0*sigma**2)))
    return g/g.sum()

def ssim(img1, img2, cs_map=False):
    """Return the Structural Similarity Map corresponding to input images img1
    and img2. The images are assumed to be 16-bit (dynamic range L = 65535).

    This function attempts to mimic precisely the functionality of ssim.m, the
    MATLAB implementation provided by the authors of SSIM:
    https://ece.uwaterloo.ca/~z70wang/research/ssim/ssim_index.m

    The cs_map argument is currently ignored.
    """
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    # Currently fixed patch size and sigma size
    size = 11
    sigma = 1.5
    window = gauss(size, sigma)
    K1 = 0.01
    K2 = 0.03
    # L = 255 #bitdepth of image
    L = 65535 #bitdepth of image

    C1 = (K1*L)**2
    C2 = (K2*L)**2
    mu1 = signal.fftconvolve(window, img1, mode='valid')
    mu2 = signal.fftconvolve(window, img2, mode='valid')
    mu1_sq = mu1*mu1
    mu2_sq = mu2*mu2
    mu1_mu2 = mu1*mu2
    sigma1_sq = signal.fftconvolve(window, img1*img1, mode='valid') - mu1_sq
    sigma2_sq = signal.fftconvolve(window, img2*img2, mode='valid') - mu2_sq
    sigma12 = signal.fftconvolve(window, img1*img2, mode='valid') - mu1_mu2
    #if cs_map:
    #    return (((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*
    #                (sigma1_sq + sigma2_sq + C2)), 
    #            (2.0*sigma12 + C2)/(sigma1_sq + sigma2_sq + C2))
    #else:
    return ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

def normalizeImageWithPercentile(img):
  # Normalisation of the predicted image for conversion to uint8 or 16bits image (necessary for SSIM)
  img_max = np.percentile(img,99.9,interpolation='nearest')
  img_min = np.percentile(img,0.1,interpolation='nearest')
  return (img - img_min)/(img_max - img_min) # For normalisation between 0 and 1.

def clipImageMinAndMax(img, min, max):
  img_clipped = np.where(img > max, max, img) 
  img_clipped = np.where(img_clipped < min, min, img_clipped)
  return img_clipped

def normalizeByLinearRegression(img1, img2):
  # Perform the fit
  linreg = LinearRegression().fit(np.reshape(img1.flatten(),(-1,1)), np.reshape(img2.flatten(), (-1,1)))

  # Get parameters of the regression fit.
  alpha = linreg.coef_
  beta = linreg.intercept_
  # print('alpha: '+str(alpha))
  # print('beta: '+str(beta))

  return img1*alpha + beta

def ssim_index(img1,img2):
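  # This function calculates a single, global SSIM index between two images,
  # using the means, variances and covariance of the whole images (no sliding window).
  # Note that the images need to be suitably normalised as below.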

  img1 = img1.astype(np.float64)
  img2 = img2.astype(np.float64)

  L = 65535
  c1 = (0.01*L)**2
  c2 = (0.03*L)**2
  
  #Mean of the images
  mu_img1 = np.mean(img1)
  mu_img2 = np.mean(img2)
  
  #Variance of the images
  var_img1 = np.mean(np.square(img1-mu_img1))
  var_img2 = np.mean(np.square(img2-mu_img2))
  
  #Covariance of the images
  cov_img1vsimg2 = np.mean((img1-mu_img1)*(img2-mu_img2))

  Numerator = (2*mu_img1*mu_img2+c1)*(2*cov_img1vsimg2+c2)
  Denominator = (mu_img1**2+mu_img2**2+c1)*(var_img1+var_img2+c2)

  SSIM_index = Numerator/Denominator

  return SSIM_index

# Open and create the csv file that will contain all the QC metrics
with open(QC_model_path+"/"+QC_model_name+"/Quality Control/QC_metrics_"+QC_model_name+".csv", "w", newline='') as file:
    writer = csv.writer(file)

    # Write the header in the csv file
    writer.writerow(["image #","Prediction v. GT mSSIM","Prediction v. GT NRMSE"])
    # print('Number of files: '+str(len(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')))

    # Let's loop through the provided dataset in the QC folders
    for thisSourceFileName in os.listdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal'):
      if not os.path.isdir(os.path.join(Source_QC_folder, thisSourceFileName)):

        print('Running QC on: '+thisSourceFileName)
        # -------------------------------- Target test data (Ground truth) --------------------------------
        test_GT = io.imread(os.path.join(QC_model_path+'/'+QC_model_name+'/Quality Control/Target', thisSourceFileName))
        test_GT = np.squeeze(test_GT,axis=(0,))
        test_GT_MIP = np.max(test_GT, axis=0)
        test_GT_norm = normalizeImageWithPercentile(test_GT_MIP) # For normalisation between 0 and 1.

        # -------------------------------- Prediction --------------------------------
        test_prediction = io.imread(os.path.join(QC_model_path+"/"+QC_model_name+"/Quality Control/Prediction",'Predicted_'+thisSourceFileName))
        test_prediction = np.squeeze(test_prediction,axis=(0,))
        test_prediction_MIP = np.max(test_prediction,axis=0)
        test_prediction_norm = normalizeImageWithPercentile(test_prediction_MIP) # For normalisation between 0 and 1.
        # Normalize the image further via linear regression wrt the normalised GT image
        test_prediction_norm = normalizeByLinearRegression(test_prediction_norm, test_GT_norm)

        # -------------------------------- Calculate the metric maps and save them --------------------------------

        # Calculate the SSIM images based on the default window parameters defined in the function
        GTforSSIM = img_as_uint(clipImageMinAndMax(test_GT_norm,0, 1), force_copy = True)
        PredictionForSSIM = img_as_uint(clipImageMinAndMax(test_prediction_norm,0, 1), force_copy = True)

        # Calculate the SSIM maps
        img_SSIM_GTvsPrediction = ssim(GTforSSIM, PredictionForSSIM)

        #Save ssim_maps
        img_SSIM_GTvsPrediction_32bit=np.float32(img_SSIM_GTvsPrediction)
        io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+thisSourceFileName,img_SSIM_GTvsPrediction_32bit)
      
        # Calculate the Root Squared Error (RSE) maps
        img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))

        # Save RSE maps
        img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)
        io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+thisSourceFileName,img_RSE_GTvsPrediction_32bit)



        # -------------------------------- Calculate the metrics and save them --------------------------------

        # Calculate the mean SSIM metric
        index_SSIM_GTvsPrediction = ssim_index(GTforSSIM, PredictionForSSIM)

        # Normalised Root Mean Squared Error: root of the mean squared difference between the normalised images
        NRMSE_GTvsPrediction = np.sqrt(np.mean(np.square(test_GT_norm - test_prediction_norm)))
        writer.writerow([thisSourceFileName,str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction)])



# All data is now processed and saved
Test_FileList = os.listdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal') # this assumes, as it should, that both source and target are named the same

plt.figure(figsize=(15,10))
# Currently only displays the last computed set, from memory
# Target (Ground-truth)
plt.subplot(2,3,1)
plt.axis('off')
img_GT = io.imread(os.path.join(QC_model_path+'/'+QC_model_name+'/Quality Control/Target', Test_FileList[-1]))
img_GT = np.squeeze(img_GT,axis=(0,))
img_GT_MIP = np.max(img_GT, axis=0)
plt.imshow(img_GT_MIP, aspect='equal')
plt.title('MIP Target')

#Setting up colours
cmap = plt.cm.Greys

# Source
plt.subplot(2,3,2)
plt.axis('off')
img_Source = io.imread(os.path.join(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal', Test_FileList[-1]))
img_Source = np.squeeze(img_Source,axis=(0,))
img_Source_MIP = np.max(img_Source,axis=0)
plt.imshow(img_Source_MIP, aspect='equal', cmap = cmap)
plt.title('MIP Source')

#Prediction
plt.subplot(2,3,3)
plt.axis('off')
img_Prediction = io.imread(os.path.join(QC_model_path+"/"+QC_model_name+"/Quality Control/Prediction", 'Predicted_'+Test_FileList[-1]))
img_Prediction = np.squeeze(img_Prediction,axis=(0,))
img_Prediction_MIP = np.max(img_Prediction,axis=0)
plt.imshow(img_Prediction_MIP, aspect='equal')
plt.title('MIP Prediction')

#Setting up colours
cmap = plt.cm.CMRmap

#SSIM between GT and Prediction
plt.subplot(2,3,5)
plt.axis('off')
imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, aspect='equal', cmap = cmap, vmin=0,vmax=1)
plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)
plt.title('MIP Target vs. MIP Prediction SSIM: '+str(round(index_SSIM_GTvsPrediction,3)))

#Root Squared Error between GT and Prediction
plt.subplot(2,3,6)
plt.axis('off')
imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, aspect='equal', cmap = cmap, vmin=0, vmax=1)
plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)
plt.title('MIP Target vs. MIP Prediction NRMSE: '+str(round(NRMSE_GTvsPrediction,3)));




#**6. Use the network**
---

<font size = 4>In this section, unseen data is processed using the model trained in section 4. First, your unseen images are uploaded and prepared for prediction. Then your trained model is used to generate predictions, which are finally saved into your Google Drive.
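<font size = 4>fnet's `predict.sh` reads the images to process from a `test.csv` file with a `path_signal` and a `path_target` column. Unseen data has no target, so the prediction cell below writes the signal path into both columns. Below is a minimal standard-library sketch of that preparation step. The helper `write_test_csv` and the example folder are hypothetical.

```python
import csv
import os
import tempfile

def write_test_csv(data_folder, csv_path):
    """Write a path_signal/path_target CSV that lists every file in data_folder.

    Hypothetical helper: with no ground truth available, the signal path is
    repeated as a placeholder target. Files are sorted for a deterministic order.
    """
    file_names = sorted(os.listdir(data_folder))
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["path_signal", "path_target"])
        for name in file_names:
            path = os.path.join(data_folder, name)
            writer.writerow([path, path])
    return file_names

# Example with a temporary folder holding two empty placeholder .tif files
with tempfile.TemporaryDirectory() as tmp:
    data_folder = os.path.join(tmp, "unseen")
    os.mkdir(data_folder)
    for name in ["b.tif", "a.tif"]:
        open(os.path.join(data_folder, name), "w").close()
    csv_path = os.path.join(tmp, "test.csv")
    names = write_test_csv(data_folder, csv_path)
    with open(csv_path, newline="") as f:
        rows = list(csv.reader(f))
    print(rows[0], len(rows))
```

Sorting the file list is a design choice made here: the row order, and therefore the order of fnet's outputs, then no longer depends on how `os.listdir` happens to order the directory.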

## **6.1. Generate prediction(s) from unseen dataset**
---

<font size = 4>The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.

<font size = 4>**`Data_folder`:** This folder should contain the images you want to process with your trained network.

<font size = 4>**`Results_folder`:** This folder will contain the predicted output images.

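<font size = 4>**`Predictions_name`:** This is currently fixed to a temporary name (`TempPredictionFolder`) that fnet uses internally. Each prediction is saved in **Results_folder** as `Prediction_` followed by the original file name, alongside a copy of the source image.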

<font size = 4>If you want to use a model different from the most recently trained one, untick the box and enter the name of the model in **`Prediction_model_name`** and its location in **`Prediction_model_path`** to use it for prediction. 

**Note: `Prediction_model_name` expects a folder name which contains a model.p file from a previous training.**
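<font size = 4>Because the model folder must contain a `model.p` file, a quick check before running the prediction cell can catch a wrong path early. This is a minimal sketch. The helper `check_model_folder` is hypothetical, and it only tests that the file exists, not that it is a valid fnet model.

```python
import os
import tempfile

def check_model_folder(model_path, model_name):
    """Return the model folder path if it contains model.p; otherwise raise FileNotFoundError."""
    folder = os.path.join(model_path, model_name)
    if not os.path.isfile(os.path.join(folder, "model.p")):
        raise FileNotFoundError(f"No model.p found in {folder}")
    return folder

with tempfile.TemporaryDirectory() as tmp:
    # A folder with an (empty, placeholder) model.p passes the check
    os.mkdir(os.path.join(tmp, "my_model"))
    open(os.path.join(tmp, "my_model", "model.p"), "w").close()
    found = check_model_folder(tmp, "my_model")
    print("OK:", found)

    # A folder name that does not exist is reported with a clear error
    missing_error = None
    try:
        check_model_folder(tmp, "missing_model")
    except FileNotFoundError as err:
        missing_error = err
        print("Error:", err)
```

In the notebook, you would call it as `check_model_folder(Prediction_model_path, Prediction_model_name)` when **use_the_current_trained_model** is unticked.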


In [0]:
#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.
#This is just in case you have already trained on a dataset with the same name
#The data will be saved outside of the pytorch_fnet folder (in Results_folder), so it won't be lost when you run this section again.

import shutil  # no need to import these, they're already imported at install
import csv

import os

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')

!pip install -U scipy==1.2.0
!pip install --no-cache-dir tifffile==2019.7.26 
from distutils.dir_util import copy_tree

#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.

Data_folder = "" #@param {type:"string"}
Results_folder = "" #@param {type:"string"}


# Predictions_name = "blabla" #@param {type:"string"}
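Predictions_name = 'TempPredictionFolder'
# predict.sh takes the dataset name as ${1:-<default>}; the closing brace is appended so that the sed command below can rewrite that default
Predictions_name_x = Predictions_name+"}"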

#If the folder you are creating already exists, delete the existing version to overwrite.
if os.path.exists(Results_folder+'/'+Predictions_name):
  shutil.rmtree(Results_folder+'/'+Predictions_name)

#@markdown ###Do you want to use the current trained model?

use_the_current_trained_model = True #@param{type:"boolean"}

#@markdown ###If not, provide the name of the model you want to use 

Prediction_model_name = "" #@param {type:"string"}
Prediction_model_path = "" #@param {type:"string"}


if use_the_current_trained_model == True:
  #Copy the contents of your trained model folder into a new folder in saved_models
  #copy_tree is used here because only the contents of the model folder are needed
  copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)
else:
  copy_tree(Prediction_model_path+'/'+Prediction_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)
  dataset = Prediction_model_name

# Get the name of the folder the test data is in
test_dataset_name = os.path.basename(os.path.normpath(Data_folder))

# Make predict.sh executable and set its default dataset name to Predictions_name.
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')
!chmod u+x /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh
!sed -i "s/1:-.*/1:-$Predictions_name_x/g" /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh

#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.
!sed -i "s/in test.*/in test/g" /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh

#If the data are .tif files, make predict.sh use fnet's TiffDataset class instead of CziDataset
file_list = os.listdir(Data_folder)
text = file_list[0]

if text.endswith('.tif') or text.endswith('.tiff'):
  !chmod u+x /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh
  !if ! grep class_dataset /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\/DIR} \\\'$''\n'     --class_dataset TiffDataset \\/' /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh; fi
  !if grep CziDataset /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh; fi   

#Create test_data folder in pytorch_fnet

# If your test data is not in the pytorch_fnet data folder it needs to be copied there.
if use_the_current_trained_model == True:
  if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name):
    shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name)
else:
  if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name):
    shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name)


# Make a folder that will hold the test.csv file in your new folder
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):
  os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)


os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')

#Make a new folder in saved_models to use the trained model for inference.
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):
  os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) 


#Get file list from the folders containing the files you want to use for inference.
#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name)
test_signal = os.listdir(Data_folder)

#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')

#If an old test csv exists we want to overwrite it, so we can insert new test data.
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):
  os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')

#Here we create a new test.csv. Unseen data has no target image, so the signal path is also written as a placeholder target.
with open('test.csv', 'w', newline='') as file:
      writer = csv.writer(file)
      writer.writerow(["path_signal","path_target"])
      for i in range(0,len(test_signal)):
        if use_the_current_trained_model ==True:
          writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+test_dataset_name+"/"+test_signal[i],"/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+test_dataset_name+"/"+test_signal[i]])
        else:
          writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+Predictions_name+"/"+test_dataset_name+"/"+test_signal[i],"/content/gdrive/My Drive/pytorch_fnet/data/"+Predictions_name+"/"+test_dataset_name+"/"+test_signal[i]])

#We run the predictions
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')
!/content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0

#Save the results
# fnet writes one sub-folder per image (named by its index) plus a few summary files.
# Keep only the sub-folders and sort them numerically so that folder k matches test_signal[k].
results_dir = '/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test'
results_files = sorted([f for f in os.listdir(results_dir) if os.path.isdir(os.path.join(results_dir, f))], key=lambda f: (len(f), f))
for i in range(len(results_files)):
  shutil.copyfile(results_dir+'/'+results_files[i]+'/prediction_'+Predictions_name+'.tiff', Results_folder+'/'+'Prediction_'+test_signal[i])
  shutil.copyfile(results_dir+'/'+results_files[i]+'/signal.tiff', Results_folder+'/'+test_signal[i])

#Comment this out if you want to see the total original results from the prediction in the pytorch_fnet folder.
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')
# shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)

##**6.2. Assess predicted output**
---
<font size = 4>Here, we inspect the first prediction made on the unseen dataset. Select the z-slice you want to visualize.
<font size = 4>**Note:** Fnet reshapes the dimensions of images for training, which may result in a loss of resolution in the output images.
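<font size = 4>The cell below removes the leading singleton axis of each saved stack before indexing a z-slice. The quality-control figures in section 5 instead show maximum intensity projections (MIPs) along z. The sketch below applies both operations to a synthetic array shaped (1, Z, Y, X). It also clamps the slice index, since fnet may change the number of z-slices. The helper name is hypothetical.

```python
import numpy as np

def slice_and_mip(stack, slice_number):
    """Return one z-slice and the maximum intensity projection of a (1, Z, Y, X) stack."""
    volume = np.squeeze(stack, axis=0)                       # drop the singleton axis -> (Z, Y, X)
    z = int(np.clip(slice_number, 0, volume.shape[0] - 1))  # keep the index inside the stack
    return volume[z], volume.max(axis=0)                     # (Y, X) slice and (Y, X) MIP

stack = np.arange(2 * 3 * 4, dtype=np.float32).reshape(1, 2, 3, 4)
one_slice, mip = slice_and_mip(stack, slice_number=16)  # 16 is clamped to the last slice, z = 1
print(one_slice.shape, mip.shape)
```

In this synthetic stack, intensities increase with z, so the MIP equals the last slice.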

In [0]:
!pip install matplotlib==2.2.3
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
import os

##@markdown ###Select the image you would you like to view?
#image = "0" #@param {type:"string"}
os.chdir(Results_folder)

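#@markdown ###Which slice would you like to view?
slice_number = 16 #@param {type:"number"}
source_image = io.imread(test_signal[0])
source_image = np.squeeze(source_image, axis=(0,))
prediction_image = io.imread('Prediction_'+test_signal[0])
prediction_image = np.squeeze(prediction_image, axis=(0,))

# Keep the slice index inside both stacks (fnet may change the number of z-slices)
slice_number = int(min(slice_number, source_image.shape[0]-1, prediction_image.shape[0]-1))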

#Create the figure
fig = plt.figure()
fig.set_figheight(10)
fig.set_figwidth(20)
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

ax1.title.set_text('Source')
ax2.title.set_text('Prediction')

#Setting up colours
cmap = plt.cm.Greys

fig1 = ax1.imshow(source_image[slice_number,:,:], cmap = cmap, aspect = 'equal')
fig1.axes.get_xaxis().set_visible(True)
fig1.axes.get_yaxis().set_visible(True)

fig2 = ax2.imshow(prediction_image[slice_number,:,:], aspect = 'equal')
fig2.axes.get_xaxis().set_visible(True)
fig2.axes.get_yaxis().set_visible(True)

## **6.3. Download your predicted output**
---

<font size = 4>**Store your data** and ALL its results elsewhere by downloading them from Google Drive. After that, clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files that have the same name.
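<font size = 4>One way to keep everything together before cleaning up is to pack a results folder into a single archive with Python's `shutil.make_archive`. You can then download that one file from Google Drive. This is a minimal sketch, and the folder names are placeholders.

```python
import os
import shutil
import tempfile

def archive_folder(folder, archive_base):
    """Zip the contents of `folder` into `archive_base`.zip and return the archive path."""
    return shutil.make_archive(archive_base, "zip", root_dir=folder)

with tempfile.TemporaryDirectory() as tmp:
    results = os.path.join(tmp, "Results_folder")  # placeholder for your own results folder
    os.mkdir(results)
    with open(os.path.join(results, "Prediction_example.tif"), "wb") as f:
        f.write(b"placeholder")
    archive_path = archive_folder(results, os.path.join(tmp, "fnet_results"))
    archive_exists = os.path.isfile(archive_path)
    print(os.path.basename(archive_path), archive_exists)
```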

In [0]:
#@markdown ##If you have checked that all your data is saved you can delete the pytorch_fnet folder from your drive by playing this cell.

import shutil
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet')