#**Label-free Prediction - Fnet**
---

<font size = 4> 
Label-free Prediction (Fnet) is a neural network used to infer the features of cellular structures from brightfield or EM images without coloured labels. The network is trained using paired training images from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescent protein). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.

---
<font size = 4> *Disclaimer*:

<font size = 4> This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It is jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.

<font size = 4> This notebook is largely based on the paper: **Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by *Chawin Ounkomol, Sharmishtaa Seshamani, Mary M. Maleckar, Forrest Collman & Gregory R. Johnson*  (https://www.nature.com/articles/s41592-018-0111-2)

<font size = 4> And source code found in: https://github.com/AllenCellModeling/pytorch_fnet

<font size = 4> **Please also cite this original paper when using or developing this notebook.** 

# **How to use this notebook?**
---

<font size = 4>Videos describing how to use our notebooks are available on YouTube:
  - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook
  - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook


---
###**Structure of a notebook**

<font size = 4>The notebook contains two types of cells:  

<font size = 4>**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.

<font size = 4>**Code cells** contain code, which can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears) and click it. When execution has finished, the play button animation stops. You can create a new code cell by clicking `+ Code`.

---
###**Table of contents, Code snippets** and **Files**

<font size = 4>On the top left side of the notebook you will find three tabs which contain, from top to bottom:

<font size = 4>*Table of contents* = contains the structure of the notebook. Click on an entry to move quickly between sections.

<font size = 4>*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.

<font size = 4>*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. 

<font size = 4>**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.

<font size = 4>**Note:** The "sample data" in "Files" contains default files. Do not upload anything in here!

---
###**Making changes to the notebook**

<font size = 4>**You can make a copy** of the notebook and save it to your Google Drive. To do this, click File -> Save a copy in Drive.

<font size = 4>To **edit a cell**, double-click on the text. This will show you either the source code (in code cells) or the source text (in text cells).
You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment.

#**0. Before getting started**
---

<font size = 4> This notebook offers two options: firstly, to download and train Fnet with data published in the original manuscript, or secondly, to upload your own dataset and train Fnet on it.
<font size = 4> The notebook may require a large amount of disk space. If you use the datasets from the paper, your Google Drive should have at least 40 GB of free space.
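
<font size = 4>A minimal sketch for checking free space before downloading large datasets, using Python's standard `shutil.disk_usage`. `has_enough_free_space` is a hypothetical helper, not part of fnet or this notebook. GB is taken as 10**9 bytes, and the value reported for a mounted Google Drive depends on how Colab exposes the mount, so treat the result as indicative rather than as your exact Drive quota.

```python
import shutil

def has_enough_free_space(path, required_gb=40):
    """Return True if the filesystem holding `path` has at least `required_gb` GB free.

    GB is taken as 10**9 bytes; the 40 GB default mirrors the guideline above.
    """
    free_bytes = shutil.disk_usage(path).free
    return free_bytes >= required_gb * 10**9

# Example, assuming Google Drive is mounted at /content/gdrive (section 1.2):
# print(has_enough_free_space('/content/gdrive/My Drive'))
```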

---
<font size = 4>**Data Format**

<font size = 4> **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on your own data, upload it to your Google Drive in the format shown below. To make sure that corresponding images are paired during training, give each signal image and its target image the same file name.

<font size = 4>Information on how to generate a training dataset is available in our Wiki page: https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki

<font size = 4> **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** 
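
<font size = 4>A quick way to check a folder name against the note above is sketched below. `find_forbidden_chars` is a hypothetical helper, not part of fnet. It treats "brackets" as `()`, `[]` and `{}` (an assumption), and it checks only the last folder name in a path, because the notebook's own paths already contain the space in `My Drive`.

```python
import os

# Characters the note above warns against; "brackets" is assumed to mean (), [] and {}.
FORBIDDEN_CHARS = set(" ()[]{}")

def find_forbidden_chars(folder_path):
    """Return the sorted list of forbidden characters found in the folder's own name."""
    name = os.path.basename(os.path.normpath(folder_path))
    return sorted(set(name) & FORBIDDEN_CHARS)

print(find_forbidden_chars('/content/gdrive/My Drive/Experiment_A'))    # []
print(find_forbidden_chars('/content/gdrive/My Drive/Experiment (A)'))  # [' ', '(', ')']
```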


*   Experiment A
    - **Training dataset**
      - bright-field images
        - img_1.tif, img_2.tif, ...
      - fluorescence images
        - img_1.tif, img_2.tif, ...
    - **Quality control dataset**
      - bright-field images
        - img_1.tif, img_2.tif
      - fluorescence images
        - img_1.tif, img_2.tif
    - **Data to be predicted**
    - **Results**
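
<font size = 4>Because signal and target images are matched by file name, it can help to compare both folders before training. The sketch below uses a hypothetical helper, `check_pairs`, which is not part of fnet or this notebook. It assumes TIFF files end in `.tif` or `.tiff` (case-insensitive) and compares names only, not image contents.

```python
import os

def check_pairs(signal_dir, target_dir, extensions=('.tif', '.tiff')):
    """Compare two folders by file name.

    Returns (paired, signal_only, target_only) as sorted lists of TIFF file names.
    """
    def tiff_names(folder):
        # Keep only files whose extension looks like TIFF
        return {f for f in os.listdir(folder) if f.lower().endswith(extensions)}

    signal = tiff_names(signal_dir)
    target = tiff_names(target_dir)
    return sorted(signal & target), sorted(signal - target), sorted(target - signal)

# Example with placeholder paths; replace them with your own training folders:
# paired, signal_only, target_only = check_pairs('/path/to/bright-field', '/path/to/fluorescence')
# print(len(paired), 'pairs; unmatched:', signal_only, target_only)
```

Any name listed in `signal_only` or `target_only` has no partner and should be renamed or removed before training.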

<font size = 4>**Important note**

<font size = 4>- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.

<font size = 4>- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.

<font size = 4>- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.

---


# **1. Initialise the Colab session**
---

## **1.1. Check for GPU access**
---

<font size = 4>By default, the session should be using Python 3 and GPU acceleration, but you can make sure these are set properly as follows:

<font size = 4>Go to **Runtime -> Change runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(Graphics processing unit)*


In [0]:
#@markdown ##Run this cell to check if you have GPU access
%tensorflow_version 1.x

import tensorflow as tf
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime?') 
  print('If the runtime settings are correct then Google did not allocate a GPU to your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')

from tensorflow.python.client import device_lib 
device_lib.list_local_devices()

## **1.2. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of the notebook.

In [0]:
#@markdown ##Play the cell to connect your Google Drive to Colab

#@markdown * Click on the URL. 

#@markdown * Sign in to your Google Account. 

#@markdown * Copy the authorization code. 

#@markdown * Enter the authorization code. 

#@markdown * Click on the "Files" tab on the left and refresh it. Your Google Drive folder should now be available here as "gdrive". 

# mount user's Google Drive to Google Colab.
from google.colab import drive
drive.mount('/content/gdrive')

#**2. Install Fnet and dependencies**
---
<font size = 4>Running fnet requires the fnet repository to be downloaded to your Google Drive. As fnet needs several packages to be installed, this step may take a few minutes.

<font size = 4>You can ignore **the error warnings** as they refer to packages not required for this notebook.

<font size = 4>**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**

In [0]:
#@markdown ##Play this cell to download fnet to your drive. If it is already installed this will only install the fnet dependencies.

import os
import csv
import shutil

#Ensure tensorflow 1.x
%tensorflow_version 1.x
import tensorflow
print(tensorflow.__version__)

print("Tensorflow enabled.")

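#Install the scipy release used by this notebook
!pip install -U scipy==1.2.0

if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet'):
  #Clone fnet from GitHub into the session, install it with its dependencies, then move it to your Google Drive
  !git clone -b release_1 --single-branch https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .
  shutil.move('/content/pytorch_fnet','/content/gdrive/My Drive/pytorch_fnet')
else:
  #fnet is already on your Google Drive: only install its dependencies into this session
  !cd "/content/gdrive/My Drive/pytorch_fnet"; pip install .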

#**3. Select your paths and parameters**
---
<font size = 5> **Paths for training data**

<font size = 4> **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data, respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files, right-click on the folder, select **Copy path** and paste the path into the appropriate box below.

<font size = 4>**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which are a power of 2 (e.g. 2,4,8,16).**
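
<font size = 4>The slice rule above can be expressed as a small check. `slice_count_ok` is a hypothetical helper, not part of fnet. It treats a single slice (2**0) as invalid because the note's examples start at 2. Reading the depth with `tifffile` assumes that Z is the first axis of each stack.

```python
def slice_count_ok(n_slices):
    """Check a stack depth against the rule above: 32 or more slices,
    or a power of 2 (a single slice is treated as invalid here)."""
    # n & (n - 1) == 0 holds exactly for powers of two (n >= 1)
    is_power_of_two = n_slices >= 2 and (n_slices & (n_slices - 1)) == 0
    return n_slices >= 32 or is_power_of_two

# With tifffile (installed in section 4.1), assuming Z is the first axis;
# the path is a placeholder:
# import tifffile
# print(slice_count_ok(tifffile.imread('/path/to/img_1.tif').shape[0]))

print([n for n in range(1, 33) if slice_count_ok(n)])  # [2, 4, 8, 16, 32]
```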

<font size = 5> **Training Parameters**

<font size = 4> **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. Reaching good performance with fnet typically requires several tens of thousands of iterations, which will usually take **several hours**, depending on the dataset size. **Default: 50000**


<font size = 5>**Advanced Parameters - experienced users only**


<font size =4>**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively, and can influence network performance. **Default: 4**

In [0]:
#@markdown ###Datasets
import random
import os
import csv
import shutil
from tempfile import mkstemp
from shutil import move, copymode
from os import fdopen, remove

#In the first step we need to change some parameters in the default files

#This function replaces the old default files with new values
def replace(file_path, pattern, subst):
    #Create temp file
    fh, abs_path = mkstemp()
    with fdopen(fh,'w') as new_file:
        with open(file_path) as old_file:
            for line in old_file:
                new_file.write(line.replace(pattern, subst))
    #Copy the file permissions from the old file to the new file
    copymode(file_path, abs_path)
    #Remove original file
    remove(file_path)
    #Move new file
    move(abs_path, file_path)

#Here we replace values in the old files
#Change maximum pixel number
replace("/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py",'n_max_pixels=9732096','n_max_pixels=20000000')
replace("/content/gdrive/My Drive/pytorch_fnet/predict.py",'6000000','20000000')

#Prevent resizing in the training and the prediction
replace("/content/gdrive/My Drive/pytorch_fnet/predict.py","0.37241","1.0")
replace("/content/gdrive/My Drive/pytorch_fnet/train_model.py","0.37241","1.0")

#Datasets

#Change checkpoints
replace("/content/gdrive/My Drive/pytorch_fnet/train_model.py","'--interval_save', type=int, default=500","'--interval_save', type=int, default=100")

#Adapt Class Dataset for Tiff files
replace("/content/gdrive/My Drive/pytorch_fnet/train_model.py","'--class_dataset', default='CziDataset'","'--class_dataset', default='TiffDataset'")

#Fetch the path and extract the name of the signal folder
Training_source = "" #@param {type: "string"}
source_name = os.path.basename(os.path.normpath(Training_source))

#Fetch the path and extract the name of the target folder
Training_target = "" #@param {type: "string"}
target_name =  os.path.basename(os.path.normpath(Training_target))

#@markdown ###Model name and model path
model_name = "" #@param {type:"string"}
model_path = "" #@param {type:"string"}

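#Delete any existing model with this name in model_path; it is replaced by the newly trained model at the end of section 4.1
if os.path.exists(model_path+'/'+model_name):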
  shutil.rmtree(model_path+'/'+model_name)
  
dataset = model_name #The name of the dataset and the model will be the same

#Here, we check if the dataset already exists. If not, copy the dataset from google drive to the data folder
  
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset):
  #shutil.copytree(own_dataset,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)
  os.makedirs('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)
  shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
  shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)
elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset) and not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name):
  shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
  shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)
elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name) and os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)
  shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
  shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)

#Create a path_csv file to point to the training images
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')

source = os.listdir('./'+dataset+'/'+source_name)
target = os.listdir('./'+dataset+'/'+target_name)

#print("Selected "+dataset+" as training set")

dataset_x = dataset+"}" # this variable is only used to ensure closed curly brackets when editing the .sh files

#We need to declare that we will run validation on the dataset
#We need to add a new line to the train.sh file
with open("/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh", "r") as f:
  if not "gpu_ids ${GPU_IDS} \\" in f.read():
    replace("/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh","       --gpu_ids ${GPU_IDS}","       --gpu_ids ${GPU_IDS} \\")

#We add the necessary validation parameters here (only once).
with open("/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh", "r") as f:
  contents = f.readlines()
if not any('PATH_DATASET_VAL_CSV=' in line for line in contents):
  #The trailing newline keeps the inserted variable on its own line
  contents.insert(10, 'PATH_DATASET_VAL_CSV="data/csvs/${DATASET}_val.csv"\n')
  contents.append('\n       --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}')
with open("/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh", "w") as f:
  f.write("".join(contents))

#Clear the White space from train.sh

with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\
     open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:
    for line in inFile:
        if line.strip():
            outFile.write(line)
os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')
os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')

#Here we define the random set of training files to be used for validation
val_files = random.sample(source,len(source)//10)
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Input'):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Input')
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Target'):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Target')

#Make validation directories
os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Input')
os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Target')
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')

#Move a random set of files from the training to the validation folders
for file in val_files:
  shutil.move('./'+dataset+'/'+source_name+'/'+file,'./'+dataset+'/Validation_Input/'+file)
  shutil.move('./'+dataset+'/'+target_name+'/'+file,'./'+dataset+'/Validation_Target/'+file)

#Redefine the source and target lists after moving the validation files
source = os.listdir('./'+dataset+'/'+source_name)
target = os.listdir('./'+dataset+'/'+target_name)

#Define Validation file lists
val_signal = os.listdir('./'+dataset+'/Validation_Input')
val_target = os.listdir('./'+dataset+'/Validation_Target')

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'_val.csv'):
  os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'_val.csv')

#Finally, we create a validation csv file to construct the validation dataset.
#Signal and target files share the same name, so rows are paired by file name (os.listdir order is arbitrary).
with open(dataset+'_val.csv', 'w', newline='') as file:
  writer = csv.writer(file)
  writer.writerow(["path_signal","path_target"])
  for name in sorted(val_signal):
    writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/Validation_Input/"+name,"/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/Validation_Target/"+name])

shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'_val.csv')

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'.csv'):
  os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'.csv')
#Pair each signal image with the target image of the same name (os.listdir order is arbitrary)
with open(dataset+'.csv', 'w', newline='') as file:
  writer = csv.writer(file)
  writer.writerow(["path_signal","path_target"])
  for name in sorted(source):
    writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+source_name+"/"+name,"/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+target_name+"/"+name])

shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'.csv')

#@markdown ---

#@markdown ###Training Parameters

#Training parameters in fnet are indicated in the train_model.sh file.
#Here, we edit this file to include the desired parameters

#1. Add permissions to train_model.sh
os.chdir("/content/gdrive/My Drive/pytorch_fnet/scripts")
!chmod u+x train_model.sh

#2. Select parameters
steps =  50000#@param {type:"number"}
batch_size =  4#@param {type:"number"}
number_of_images =  len(source)

#3. Insert the above values into train_model.sh
!if ! grep saved_models\/\${ train_model.sh;then sed -i 's/saved_models\/.*/saved_models\/\${DATASET}"/g' train_model.sh; fi 
!sed -i "s/1:-.*/1:-$dataset_x/g" train_model.sh #change the dataset to be trained with
!sed -i "s/N_ITER=.*/N_ITER=$steps/g" train_model.sh #change the number of training iterations (steps)
!sed -i "s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g" train_model.sh #change the number of training images
!sed -i "s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g" train_model.sh #change the batch size

#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.
!sed -i "s/train_size .* -v/train_size 1.0 -v/g" train_model.sh

#If new parameters are inserted here for training a model with the same name
#the previous training csv needs to be removed, to prevent the model using the old training split or paths.
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset)

#**4. Train the model**
---
<font size = 4>Before training, carefully read the different options. This applies especially if you have trained fnet on a dataset before.


###**Choose one of the options to train fnet**.

<font size = 4>**4.1.** If this is the first training on the chosen dataset, play this section to start training.

<font size = 4>**4.2.** If you want to continue training a previously trained model, choose this section.

<font size = 4><font color = red> **Carefully read the options before starting training.**

##**4.1. Train a new model**
---

####Play the cell below to start training. 

<font size = 4>**Note:** If you are training a model with the same name as before, the model will be overwritten. If you want to keep the previous model, save a copy of it before running section 3 (which deletes any existing model with that name in `model_path`), or give your model a different name (section 3).
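
<font size = 4>A minimal sketch for keeping a timestamped copy of an existing model folder before it is overwritten. `backup_model` is a hypothetical helper, not part of fnet. The cell in section 3 already removes `model_path/model_name`, so the backup has to be made before that cell runs.

```python
import datetime
import os
import shutil

def backup_model(model_dir):
    """Copy `model_dir` to a sibling folder with a timestamp suffix and return its path.

    Returns None if `model_dir` does not exist (nothing to back up).
    """
    if not os.path.isdir(model_dir):
        return None
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    backup_dir = os.path.normpath(model_dir) + '_backup_' + stamp
    shutil.copytree(model_dir, backup_dir)
    return backup_dir

# Example, using the model_path and model_name entered in section 3:
# print(backup_model(model_path + '/' + model_name))
```

The timestamp has one-second resolution, so two backups of the same folder made within the same second would collide and `shutil.copytree` would raise an error.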

In [0]:
import datetime
import time
start = time.time()

#Delete any previous model with the same name from pytorch_fnet/saved_models
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset)

#@markdown ##4.1. Start Training
#with 50 images in buffer from dataset and 5000 epochs, takes around 1:30h

#This tifffile release runs error-free in this version of fnet.
!pip install tifffile==2019.7.26

os.chdir('/content/gdrive/My Drive/pytorch_fnet/fnet/')

#Here we import an additional module to the functions.py file to run it without errors.
#This may be a small bug in the original code.

with open("functions.py", "r") as f:
  contents = f.readlines()
if not any('import fnet.fnet_model' in line for line in contents):
  #The trailing newline keeps the import on its own line
  contents.insert(5, 'import fnet.fnet_model\n')
with open("functions.py", "w") as f:
  f.write("".join(contents))
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')


print("Let's start the training!")
#Here we start the training
!./scripts/train_model.sh $dataset 0

#After training overwrite any existing model in the model_path with the new trained model.
if os.path.exists(model_path+'/'+dataset):
  shutil.rmtree(model_path+'/'+dataset)
shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset,model_path+'/'+dataset)

shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'_val.csv',model_path+'/'+model_name+'/'+dataset+'_val.csv')
#Get rid of duplicates of training data in pytorch_fnet after training completes
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Input')
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Target')


# Displaying the time elapsed for training
dt = time.time() - start
mins, sec = divmod(dt, 60) 
hour, mins = divmod(mins, 60) 
print("Time elapsed:",hour, "hour(s)",mins,"min(s)",round(sec),"sec(s)")


<font size = 4>**Note:** Fnet takes a long time to train. If your notebook times out due to the length of the training or a loss of GPU acceleration, the last checkpoint will be saved in the `saved_models` folder in the `pytorch_fnet` folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and, in the next cell, enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. **If you did not time out, you can ignore this section.**
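
<font size = 4>To see how far a saved checkpoint got, you can count the rows of the `losses.csv` file in its model folder. Section 4.2 of this notebook does the same. The sketch below assumes, as that section does, that the file has one header row followed by one row per completed iteration. `completed_steps` is a hypothetical helper.

```python
import os

def completed_steps(model_dir):
    """Estimate how many iterations a saved fnet model has completed.

    Assumes losses.csv has one header row plus one row per iteration,
    which is how section 4.2 of this notebook counts steps.
    """
    losses_csv = os.path.join(model_dir, 'losses.csv')
    if not os.path.exists(losses_csv):
        return 0
    with open(losses_csv) as f:
        n_lines = sum(1 for line in f if line.strip())  # ignore blank lines
    return max(n_lines - 1, 0)  # subtract the header row

# Example for a checkpoint left in pytorch_fnet after a timeout:
# print(completed_steps('/content/gdrive/My Drive/pytorch_fnet/saved_models/' + model_name))
```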

In [0]:
#@markdown ##Play this cell if your model training timed out and indicate where you want to save the last checkpoint.

import shutil
import os
model_name = "" #@param {type:"string"}
model_path = "" #@param {type:"string"}

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):
  shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)
else:
  print('This model name does not exist in your saved_models folder. Make sure you have entered the name of the model that timed out.')

##**4.2. Training from a previously saved model**
---
<font size = 4>This section allows you to use networks you have previously trained and saved and to continue training them for more steps. The folders have the same meaning as above (section 3). If you want to keep the previously trained model, create a copy now, as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on.**

<font size = 4>**Note: To use this section, the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (To make sure, simply play the cell in section 2.)**
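
<font size = 4>When training continues, the notebook reuses the validation files of the first run so that the model is not validated on images it has trained on. Their names are read from the `<model_name>_val.csv` file that section 4.1 copies into the model folder. Below is a minimal sketch of that lookup using Python's `csv.DictReader`. `validation_file_names` is a hypothetical helper name, and the `path_signal,path_target` header is the one written in section 3.

```python
import csv
import os

def validation_file_names(val_csv_path):
    """Return the signal file names listed in a <model_name>_val.csv file."""
    with open(val_csv_path, newline='') as f:
        reader = csv.DictReader(f)  # uses the header row: path_signal,path_target
        return [os.path.basename(os.path.normpath(row['path_signal'])) for row in reader]

# Example, using the model_path and model_name entered below:
# print(validation_file_names(model_path + '/' + model_name + '/' + model_name + '_val.csv'))
```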

In [0]:
#@markdown To test if performance improves after the initial training, you can continue training on the old model. This option can also be useful if Colab disconnects or times out.
#@markdown Enter the paths of the datasets you want to continue training on.
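import csv
import os
import shutil
from tempfile import mkstemp
from shutil import move, copymode
from os import fdopen, remove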

def replace(file_path, pattern, subst):
    #Create temp file
    fh, abs_path = mkstemp()
    with fdopen(fh,'w') as new_file:
        with open(file_path) as old_file:
            for line in old_file:
                new_file.write(line.replace(pattern, subst))
    #Copy the file permissions from the old file to the new file
    copymode(file_path, abs_path)
    #Remove original file
    remove(file_path)
    #Move new file
    move(abs_path, file_path)

#Here we replace values in the old files
#Change maximum pixel number
replace("/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py",'n_max_pixels=9732096','n_max_pixels=20000000')
replace("/content/gdrive/My Drive/pytorch_fnet/predict.py",'6000000','20000000')

#Prevent resizing in the training and the prediction
replace("/content/gdrive/My Drive/pytorch_fnet/predict.py","0.37241","1.0")
replace("/content/gdrive/My Drive/pytorch_fnet/train_model.py","0.37241","1.0")

#We add the necessary validation parameters here (only once).
with open("/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh", "r") as f:
  contents = f.readlines()
if not any('PATH_DATASET_VAL_CSV=' in line for line in contents):
  #The trailing newline keeps the inserted variable on its own line
  contents.insert(10, 'PATH_DATASET_VAL_CSV="data/csvs/${DATASET}_val.csv"\n')
  contents.append('\n       --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}')
with open("/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh", "w") as f:
  f.write("".join(contents))

#Clear the White space from train.sh

with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\
     open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:
    for line in inFile:
        if line.strip():
            outFile.write(line)
os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')
os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')

#Datasets

#Change checkpoints
replace("/content/gdrive/My Drive/pytorch_fnet/train_model.py","'--interval_save', type=int, default=500","'--interval_save', type=int, default=100")

#Adapt Class Dataset for Tiff files
replace("/content/gdrive/My Drive/pytorch_fnet/train_model.py","'--class_dataset', default='CziDataset'","'--class_dataset', default='TiffDataset'")


#Fetch the path and extract the name of the signal folder
Training_source = "" #@param {type: "string"}
source_name = os.path.basename(os.path.normpath(Training_source))

#Fetch the path and extract the name of the target folder
Training_target = "" #@param {type: "string"}
target_name =  os.path.basename(os.path.normpath(Training_target))

model_name = "" #@param {type:"string"}
dataset = model_name

model_path = "" #@param {type:"string"}

batch_size = 4 #@param {type:"number"}

dataset_x = dataset+"}"

#Move your model to fnet
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset):
  shutil.copytree(model_path+'/'+model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset)

#Move the datasets into fnet
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset):
  os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)
  shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
  shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)
elif not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name):
  shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
  shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)

os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')

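#Number of training images: all source images minus the validation images listed in the previous validation csv (minus its header row)
with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'r') as f:
  n_val_images = sum(1 for line in f if line.strip()) - 1
number_of_images = len(os.listdir(Training_source)) - n_val_images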

#Change the train_model.sh file to include chosen dataset
!chmod u+x train_model.sh
!sed -i "s/1:-.*/1:-$dataset_x/g" train_model.sh
!sed -i "s/train_size .* -v/train_size 1.0 -v/g" train_model.sh #Use the whole training dataset for training
!sed -i "s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g" train_model.sh #change the number of training images
!sed -i "s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g" train_model.sh #change the batch size


# We will use the same validation files from the training dataset as used before,
# This makes sure that the model is not validated with files it has seen in training before saving.

#First we get the names of the validation files from the previous training which are saved in the validation csv.
val_source_list = []

with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'r') as f:
  contents = csv.reader(f,delimiter=',')
  for row in contents:
    val_source_list.append(row[0])

#Get the file list without the header
val_source_list = val_source_list[1::]

#Get only the file names and not the full path
for i in range(0,len(val_source_list)):
  val_source_list[i] = os.path.basename(os.path.normpath(val_source_list[i]))

source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Input'):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Input')
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Target'):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Target')

#Make validation directories
os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Input')
os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/Validation_Target')
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')

#Move the validation files used in the previous training from the training folders to the validation folders
for file in val_source_list:
  shutil.move('./'+dataset+'/'+source_name+'/'+file,'./'+dataset+'/Validation_Input/'+file)
  shutil.move('./'+dataset+'/'+target_name+'/'+file,'./'+dataset+'/Validation_Target/'+file)

#Redefine the source and target lists after moving the validation files
source = os.listdir('./'+dataset+'/'+source_name)
target = os.listdir('./'+dataset+'/'+target_name)

#Define Validation file lists
val_signal = os.listdir('./'+dataset+'/Validation_Input')
val_target = os.listdir('./'+dataset+'/Validation_Target')

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'_val.csv'):
  os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'_val.csv')

shutil.move(model_path+'/'+model_name+'/'+dataset+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'_val.csv')

#Make a training csv file.
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset)
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')
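#Sort both lists so that each source file is paired with the target file of the same name (os.listdir order is arbitrary)
source = sorted(os.listdir('./'+dataset+'/'+source_name))
target = sorted(os.listdir('./'+dataset+'/'+target_name))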
with open(dataset+'.csv', 'w', newline='') as file:
  writer = csv.writer(file)
  writer.writerow(["path_signal","path_target"])
  for i in range(0,len(source)):
    writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+source_name+"/"+source[i],"/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+target_name+"/"+target[i]])

shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'.csv')

#Find the number of previous training iterations (steps) from loss csv file

with open(model_path+'/'+dataset+'/losses.csv') as f:
  previous_steps = sum(1 for line in f)
print('continuing training after step '+str(previous_steps-1))

print('To start re-training, play section 4.2 below')

#@markdown For how many additional steps do you want to train the model?
add_steps =  5000#@param {type:"number"}

#Calculate the new total number of training steps. Subtract 1 to discount the header row of the csv file.
new_steps = previous_steps + add_steps -1
os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')

#Edit the train_model.sh file to set the new total number of training steps
!sed -i "s/N_ITER=.*/N_ITER=$new_steps/g" train_model.sh
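<font size = 4>The step bookkeeping in the cell above can be sketched as a small, testable helper. This is a minimal sketch assuming, as the cell does, that `losses.csv` holds one header row followed by one row per completed training iteration; `total_training_steps` is a hypothetical name and not part of pytorch_fnet.

```python
import csv

def total_training_steps(losses_csv_path, additional_steps):
    """Return (previous_steps, new_total) from an fnet-style losses.csv.

    Assumption: the file has one header row followed by one row per
    completed training iteration.
    """
    with open(losses_csv_path, newline='') as f:
        n_rows = sum(1 for _ in csv.reader(f))
    previous_steps = max(n_rows - 1, 0)  # discount the header row
    return previous_steps, previous_steps + additional_steps
```

<font size = 4>For example, `total_training_steps(model_path+'/'+dataset+'/losses.csv', add_steps)` would give the same total as `new_steps` above, which is the value written to `N_ITER`.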

In [0]:
import datetime
import time
start = time.time()

#@markdown ##4.2. Start re-training model
!pip install tifffile==2019.7.26
import os
os.chdir('/content/gdrive/My Drive/pytorch_fnet/fnet')

#Here we add an import to the functions.py file so that it runs without errors.
with open("functions.py", "r") as f:
  contents = f.readlines()
if not any('import fnet.fnet_model' in line for line in contents):
  #The newline keeps the inserted import on its own line
  contents.insert(5, 'import fnet.fnet_model\n')
  with open("functions.py", "w") as f:
    f.write("".join(contents))

#Here we retrain the model on the chosen dataset.
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')
!chmod u+x ./scripts/train_model.sh
!./scripts/train_model.sh $dataset 0

#Update the existing model_path folder with the new training results.
if os.path.exists(model_path+'/'+dataset):
  shutil.rmtree(model_path+'/'+dataset)
shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset,model_path+'/'+dataset)

#Get rid of duplicates of training data in pytorch_fnet after training completes
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_name)
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_name)

shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+dataset+'_val.csv',model_path+'/'+model_name+'/'+dataset+'_val.csv')
# Display the time elapsed for training
dt = time.time() - start
minutes, sec = divmod(dt, 60)
hours, minutes = divmod(minutes, 60)
print("Time elapsed:", int(hours), "hour(s)", int(minutes), "min(s)", round(sec), "sec(s)")

# **5. Evaluate your model**
---

This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 

**We highly recommend performing quality control on all newly trained models.**

In [0]:
import os
import shutil
# model name and path
#@markdown ###Do you want to assess the model you just trained?

Use_the_current_trained_model = False #@param {type:"boolean"}

#@markdown ###If not, please provide the name of the model and path to model folder:
#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter 'model_name' (see section 3). Provide the name of this folder as 'QC_model_name' and the path to its parent folder in 'QC_model_path'. 

QC_model_name = "" #@param {type:"string"}
QC_model_path = "" #@param {type:"string"}

if (Use_the_current_trained_model): 
  print("Using current trained network")
  QC_model_name = model_name
  QC_model_path = model_path

#Check that the model exists before creating anything inside its folder
full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'
if os.path.exists(full_QC_model_path):
  print("The "+QC_model_name+" network will be evaluated")
  #Create a folder for the quality control metrics (overwriting any previous one)
  if os.path.exists(full_QC_model_path+"Quality Control"):
    shutil.rmtree(full_QC_model_path+"Quality Control")
  os.makedirs(full_QC_model_path+"Quality Control")
else:
  W  = '\033[0m'  # white (normal)
  R  = '\033[31m' # red
  print(R+'!! WARNING: The chosen model does not exist !!'+W)
  print('Please make sure you provide a valid model path and model name before proceeding further.')


## **5.1. Inspection of the loss function**
---

<font size = 4>First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*

<font size = 4>**Training loss** is the error between the model's prediction and its ground-truth target, recorded at each training iteration.

<font size = 4>**Validation loss** is the same error, computed between the model's prediction on a validation image and its target.

<font size = 4>During training, both values should decrease before reaching a minimal value that does not decrease further even with more training. Comparing the development of the validation loss with that of the training loss can give insights into the model's performance.

<font size = 4>If both the **Training loss** and the **Validation loss** are still decreasing, further training is needed and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side just because of the y-axis scaling, which is why the loss is also plotted on a log scale. The network has reached convergence once the curves flatten out; after this point, no further training is required. If the **Validation loss** suddenly increases again while the **Training loss** simultaneously goes towards zero, the network is overfitting to the training data. In other words, the network is memorising the exact patterns of the training data and no longer generalises well to unseen data. In this case, the training dataset should be enlarged.
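<font size = 4>The overfitting pattern described above (validation loss rising while training loss keeps falling) can also be checked numerically. The following is a minimal heuristic sketch, not part of fnet or of this notebook: it compares the mean loss over the two halves of the most recent `window` values. `loss_trend`, `looks_overfitted`, `window` and `tol` are hypothetical names, and the default window size is an assumption to adapt to how often losses are logged.

```python
import numpy as np

def loss_trend(values, window=10):
    """Change in mean loss between the two halves of the last `window` values.

    Negative: still decreasing; near zero: flattening; positive: increasing.
    """
    recent = np.asarray(values[-window:], dtype=float)
    half = len(recent) // 2
    if half == 0:
        raise ValueError("need at least two loss values")
    return recent[half:].mean() - recent[:half].mean()

def looks_overfitted(train_loss, val_loss, window=10, tol=0.0):
    """Heuristic: validation loss rising while training loss is still falling."""
    return loss_trend(val_loss, window) > tol and loss_trend(train_loss, window) < -tol
```

<font size = 4>With the lists read in the next cell, this could be called as `looks_overfitted(lossDataFromCSV, vallossDataFromCSV)`. Because the training loss is recorded per batch and can be noisy, a larger window than the default may give a more stable answer; the plots remain the primary check.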




In [0]:
#@markdown ##Play the cell to show figure of training errors

lossDataFromCSV = []
vallossDataFromCSV = []

iterationNumber_training = []
iterationNumber_val = []

import csv
from matplotlib import pyplot as plt
with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:
    plots = csv.reader(csvfile, delimiter=',')
    next(plots)
    for row in plots:
        iterationNumber_training.append(int(row[0]))
        lossDataFromCSV.append(float(row[1]))

with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:
  plots = csv.reader(csvfile_val, delimiter=',')
  next(plots)
  for row in plots:
    iterationNumber_val.append(int(row[0]))
    vallossDataFromCSV.append(float(row[1]))

plt.figure(figsize=(15,10))

plt.subplot(2,1,1)
plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')
plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')
plt.title('Training loss and validation loss vs. iteration number (linear scale)')
plt.ylabel('Loss')
plt.xlabel('Iteration')
plt.legend()

plt.subplot(2,1,2)
plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')
plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')
plt.title('Training loss and validation loss vs. iteration number (log scale)')
plt.ylabel('Loss')
plt.xlabel('Iteration')
plt.legend()
plt.savefig(QC_model_path+'/'+QC_model_name+'/'+'losses.png')
plt.show()


## **5.2. Error mapping and quality metrics estimation**
---

<font size = 4>This section displays SSIM maps and RSE maps, and calculates the mSSIM, NRMSE and PSNR metrics for all the images provided in the `Source_QC_folder` and `Target_QC_folder`.

<font size = 4>**1. The SSIM (structural similarity) map** 

<font size = 4>The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric, and an SSIM of 1 indicates perfect similarity between two images. Therefore, for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric at each pixel, considering the structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with Gaussian weighting of 1.5 pixels standard deviation; see our Wiki for more info).

<font size=4>**mSSIM** is the mean SSIM value, i.e. the SSIM map averaged over the image.
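<font size = 4>In scikit-image, the window described above (11 pixels, Gaussian weighting with a standard deviation of 1.5 pixels, as in Wang *et al.*) is obtained by passing `gaussian_weights=True, sigma=1.5, use_sample_covariance=False` to `skimage.metrics.structural_similarity`; with the default arguments, a uniform 7x7 window is used instead. A minimal sketch, assuming two 2D images already normalised to [0, 1]; `ssim_map` is a hypothetical helper name.

```python
import numpy as np
from skimage.metrics import structural_similarity

def ssim_map(gt, pred):
    """Return (mSSIM, SSIM map) for two 2D float images normalised to [0, 1]."""
    mssim, smap = structural_similarity(
        gt, pred,
        data_range=1.0,               # images are normalised to [0, 1]
        gaussian_weights=True,        # Gaussian instead of uniform window
        sigma=1.5,                    # 1.5 px standard deviation -> 11 px window
        use_sample_covariance=False,  # population covariance, as in Wang et al.
        full=True,                    # also return the per-pixel SSIM map
    )
    return mssim, smap
```

<font size = 4>The returned mSSIM is the mean of the SSIM map, excluding a border of half the window width where the window does not fit entirely inside the image.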

<font size=4>**The output below shows the SSIM maps with the mSSIM**

<font size = 4>**2. The RSE (Root Squared Error) map** 

<font size = 4>This is a display of the root of the squared difference between the normalized prediction and target, or between the source and target. In this case, a smaller RSE is better. Perfect agreement between target and prediction leads to an RSE map showing zeros everywhere (dark).


<font size =4>**NRMSE (normalised root mean squared error)** is the square root of the mean squared difference over all pixels of the two normalised images. Good agreement yields low NRMSE scores.

<font size = 4>**PSNR (peak signal-to-noise ratio)** expresses the difference between the ground truth and the prediction (or source input) in decibels, computed from the peak value of the data range (here 1.0, as the images are normalised) and the MSE between the images. The higher the score, the better the agreement.
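<font size = 4>For two images normalised to [0, 1], the RSE map and the two scalar metrics reduce to a few NumPy operations. This sketch assumes a data range of 1.0 for PSNR, matching the `data_range=1.0` used in the quality-control cell; "normalised" in NRMSE refers to the images having been normalised beforehand. `rse_nrmse_psnr` is a hypothetical helper name.

```python
import numpy as np

def rse_nrmse_psnr(gt, pred, data_range=1.0):
    """Return (RSE map, NRMSE, PSNR in dB) for two normalised images."""
    diff = np.asarray(gt, dtype=np.float64) - np.asarray(pred, dtype=np.float64)
    rse_map = np.sqrt(np.square(diff))   # root squared error per pixel, i.e. |diff|
    mse = np.mean(np.square(diff))       # mean squared error over all pixels
    nrmse = np.sqrt(mse)                 # root of the MEAN squared error
    psnr = 10 * np.log10(data_range ** 2 / mse) if mse > 0 else np.inf
    return rse_map, nrmse, psnr
```

<font size = 4>For instance, a prediction that is off by 0.1 everywhere gives an RSE map of 0.1, an NRMSE of 0.1 and a PSNR of 10·log10(1/0.01) = 20 dB. Identical images give an infinite PSNR.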

<font size=4>**The output below shows the RSE maps with the NRMSE and PSNR values.**


In [0]:
import shutil  # no need to import these, they're already imported at install
import csv
import os
from tempfile import mkstemp
from shutil import move, copymode
from os import fdopen, remove
from skimage import img_as_float32

from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as psnr

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')

!pip install -U scipy==1.2.0
!pip install --no-cache-dir tifffile==2019.7.26 
from distutils.dir_util import copy_tree

# This function replaces values in the fnet files which we need to change.

def replace(file_path, pattern, subst):
    #Create temp file
    fh, abs_path = mkstemp()
    with fdopen(fh,'w') as new_file:
        with open(file_path) as old_file:
            for line in old_file:
                new_file.write(line.replace(pattern, subst))
    #Copy the file permissions from the old file to the new file
    copymode(file_path, abs_path)
    #Remove original file
    remove(file_path)
    #Move new file
    move(abs_path, file_path)

#Here we replace values in the old files
#Change maximum pixel number
replace("/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py",'n_max_pixels=9732096','n_max_pixels=20000000')
replace("/content/gdrive/My Drive/pytorch_fnet/predict.py",'6000000','20000000')

#Prevent resizing in the training and the prediction
replace("/content/gdrive/My Drive/pytorch_fnet/predict.py","0.37241","1.0")
replace("/content/gdrive/My Drive/pytorch_fnet/train_model.py","0.37241","1.0")


#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#


#Choose the folder with the quality control datasets
Source_QC_folder = "" #@param{type:"string"}
Target_QC_folder = "" #@param{type:"string"}

Predictions_name = "QualityControl" 
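Predictions_name_x = Predictions_name+"}" #The sed command below replaces "1:-..." up to the end of the line, so the closing brace of ${1:-...} is added back here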

#If the folder you are creating already exists, delete the existing version to overwrite.
if os.path.exists(QC_model_path+"/"+QC_model_name+"/Quality Control/"+Predictions_name):
  shutil.rmtree(QC_model_path+"/"+QC_model_name+"/Quality Control/"+Predictions_name)

if Use_the_current_trained_model == True:
  #Move the contents of the saved_models folder from your training to the new folder
  #Here, we use a different copyfunction as we only need the contents of the trained_model folder
  copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)
else:
  copy_tree(QC_model_path+'/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)
  dataset = QC_model_name

# Get the name of the folder the test data is in
source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))
target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))

# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')
!chmod u+x /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh
!sed -i "s/1:-.*/1:-$Predictions_name_x/g" /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh

#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.
!sed -i "s/in test.*/in test/g" /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh

#Check that we are using .tif files
file_list = os.listdir(Source_QC_folder)
text = file_list[0]

if text.endswith('.tif') or text.endswith('.tiff'):
  !chmod u+x /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh
  !if ! grep class_dataset /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\/DIR} \\\'$''\n'     --class_dataset TiffDataset \\/' /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh; fi
  !if grep CziDataset /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh; fi   

#Create test_data folder in pytorch_fnet

# If your test data is not in the pytorch_fnet data folder it needs to be copied there.
if Use_the_current_trained_model == True:
  if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name):
    shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name)
    shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+target_dataset_name)
else:
  if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name):
    shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name)
    shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+target_dataset_name)


# Make a folder that will hold the test.csv file in your new folder
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):
  os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)


os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')

#Make a new folder in saved_models to use the trained model for inference.
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):
  os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) 


#Get file list from the folders containing the files you want to use for inference.
#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name)
test_signal = os.listdir(Source_QC_folder)
test_target = os.listdir(Target_QC_folder)
#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')

#If an old test csv exists we want to overwrite it, so we can insert new test data.
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):
  os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')

#Here we create a new test.csv
with open('test.csv', 'w', newline='') as file:
      writer = csv.writer(file)
      writer.writerow(["path_signal","path_target"])
      for i in range(0,len(test_signal)):
        if Use_the_current_trained_model == True:
          writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+source_dataset_name+"/"+test_signal[i],"/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+target_dataset_name+"/"+test_signal[i]])
          # This assumes that source and target files have identical names: the target path is built from test_signal, not test_target
        else:
          writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+Predictions_name+"/"+source_dataset_name+"/"+test_signal[i],"/content/gdrive/My Drive/pytorch_fnet/data/"+Predictions_name+"/"+target_dataset_name+"/"+test_signal[i]])

#We run the predictions
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')
!/content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0

#Save the results
QC_results_dir = '/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test'
#Keep only the per-image result folders and sort them by index (assuming they are named by row number), so that they line up with the rows of test.csv
QC_results_files = sorted([d for d in os.listdir(QC_results_dir) if os.path.isdir(os.path.join(QC_results_dir, d))], key=lambda d: (len(d), d))

if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction'):
  shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')
os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')

if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal'):
  shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')
os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')

if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Target'):
  shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')
os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')

for i in range(len(QC_results_files)):
  shutil.copyfile(QC_results_dir+'/'+QC_results_files[i]+'/prediction_'+Predictions_name+'.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction/'+'Predicted_'+test_signal[i])
  shutil.copyfile(QC_results_dir+'/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Signal/'+test_signal[i])
  #test.csv pairs each source with the target of the same name, so the target is saved under that name
  shutil.copyfile(QC_results_dir+'/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Target/'+test_signal[i])

shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)


#-----------------------------METRICS EVALUATION-------------------------------#

##@markdown ##Give the paths to an image to test the performance of the model with.
!pip install matplotlib==2.2.3
import sys
import numpy as np
from scipy import signal
from scipy import ndimage
from skimage import io
from matplotlib import pyplot as plt
import pandas as pd
#from skimage.util import img_as_uint
import matplotlib as mpl

def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):
    """Percentile-based image normalization (adapted from Martin Weigert)."""

    mi = np.percentile(x,pmin,axis=axis,keepdims=True)
    ma = np.percentile(x,pmax,axis=axis,keepdims=True)
    return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)


def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32
    """This function is adapted from Martin Weigert"""
    if dtype is not None:
        x   = x.astype(dtype,copy=False)
        mi  = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)
        ma  = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)
        eps = dtype(eps)

    try:
        import numexpr
        x = numexpr.evaluate("(x - mi) / ( ma - mi + eps )")
    except ImportError:
        x =                   (x - mi) / ( ma - mi + eps )

    if clip:
        x = np.clip(x,0,1)

    return x

def norm_minmse(gt, x, normalize_gt=True):
    """Normalizes and affinely scales an image pair such that the MSE is minimized.

    This function is adapted from Martin Weigert.

    Parameters
    ----------
    gt: ndarray
        the ground truth image
    x: ndarray
        the image that will be affinely scaled
    normalize_gt: bool
        set to True if the gt image should be normalized (default)

    Returns
    -------
    gt_scaled, x_scaled
    """
    if normalize_gt:
        gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)
    x = x.astype(np.float32, copy=False) - np.mean(x)
    gt = gt.astype(np.float32, copy=False) - np.mean(gt)
    # Least-squares scale factor; bias=True makes np.cov use the same 1/N normalisation as np.var
    scale = np.cov(x.flatten(), gt.flatten(), bias=True)[0, 1] / np.var(x.flatten())
    return gt, scale * x

# Find the position of the middle slice (used for display)
img = io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))
n_slices = img.shape[0]
z_mid_plane = n_slices // 2

path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'

# Open and create the csv file that will contain all the QC metrics
with open(path_metrics_save+'QC_metrics_'+QC_model_name+".csv", "w", newline='') as file:
    writer = csv.writer(file)

    # Write the header in the csv file
    writer.writerow(["File name","Slice #","Prediction v. GT mSSIM", "Prediction v. GT NRMSE", "Prediction v. GT PSNR"])  
    
    # These lists will be used to collect all the metrics values per slice
    file_name_list = []
    slice_number_list = []
    mSSIM_GvP_list = []
    NRMSE_GvP_list = []
    PSNR_GvP_list = []

    # These lists will be used to display the mean metrics for the stacks
    mSSIM_GvP_list_mean = []
    NRMSE_GvP_list_mean = []
    PSNR_GvP_list_mean = []

    # Let's loop through the provided dataset in the QC folders
    for thisFile in os.listdir(Source_QC_folder):
      if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):
        print('Running QC on: '+thisFile)

        test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))
        test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))
        test_prediction_stack = io.imread(os.path.join(path_metrics_save+"Prediction/",'Predicted_'+thisFile))
        test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))
        n_slices = test_GT_stack.shape[0]

        img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))
        img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))

        for z in range(n_slices): 
          
          # -------------------------------- Prediction --------------------------------

          test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)

          # -------------------------------- Calculate the SSIM metric and maps --------------------------------

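          # Calculate the SSIM maps and index (Gaussian-weighted 11-pixel window, sigma = 1.5, as described above)
          index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False)

          #Store the SSIM map of this slice
          img_SSIM_GTvsPrediction_stack[z] = np.float32(img_SSIM_GTvsPrediction)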
      

          # -------------------------------- Calculate the NRMSE metrics --------------------------------

          # Calculate the Root Squared Error (RSE) map
          img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))

          # Store the RSE map of this slice
          img_RSE_GTvsPrediction_stack[z] = np.float32(img_RSE_GTvsPrediction)


          # Normalised Root Mean Squared Error: the square root of the mean of the squared errors
          NRMSE_GTvsPrediction = np.sqrt(np.mean(np.square(test_GT_norm - test_prediction_norm)))


          # Calculate the PSNR between the images
          PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)


          writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])
          
          # Collect values to display in dataframe output
          #file_name_list.append(thisFile)
          slice_number_list.append(z)
          mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)

          NRMSE_GvP_list.append(NRMSE_GTvsPrediction)

          PSNR_GvP_list.append(PSNR_GTvsPrediction)


          if (z == z_mid_plane): # catch these for display
            SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction

            NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction

        
        # Average the metrics over the slices of this stack only (its last n_slices entries) for the dataframe output
        file_name_list.append(thisFile)
        mSSIM_GvP_list_mean.append(np.mean(mSSIM_GvP_list[-n_slices:]))

        NRMSE_GvP_list_mean.append(np.mean(NRMSE_GvP_list[-n_slices:]))

        PSNR_GvP_list_mean.append(np.mean(PSNR_GvP_list[-n_slices:]))

        # ----------- Change the stacks to 32 bit images -----------
        img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)
        img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)


        # ----------- Saving the error map stacks -----------
        io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)
        io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)

#Averages of the metrics per stack as dataframe output
pdResults = pd.DataFrame(file_name_list, columns = ["File name"])
pdResults["Prediction v. GT mSSIM"] = mSSIM_GvP_list_mean

pdResults["Prediction v. GT NRMSE"] = NRMSE_GvP_list_mean

pdResults["Prediction v. GT PSNR"] = PSNR_GvP_list_mean

pdResults.head()

# All data is now processed and saved
Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way

plt.figure(figsize=(15,10))
# Currently only displays the last computed set, from memory

# Target (Ground-truth)
plt.subplot(2,3,1)
plt.axis('off')
img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))
plt.imshow(img_GT[z_mid_plane])
plt.title('Target (slice #'+str(z_mid_plane)+')')


#Setting up colours
cmap = plt.cm.Greys


# Source
plt.subplot(2,3,2)
plt.axis('off')
img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))
plt.imshow(img_Source[z_mid_plane],aspect='equal',cmap=cmap)
plt.title('Source (slice #'+str(z_mid_plane)+')')


#Prediction
plt.subplot(2,3,3)
plt.axis('off')
img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))
img_Prediction = np.squeeze(img_Prediction,axis=(0,))
plt.imshow(img_Prediction[z_mid_plane])
plt.title('Prediction (slice #'+str(z_mid_plane)+')')

#Setting up colours
cmap = plt.cm.CMRmap

#SSIM between GT and Prediction
plt.subplot(2,3,5)
#plt.axis('off')
plt.tick_params(
    axis='both',      # changes apply to the x-axis and y-axis
    which='both',      # both major and minor ticks are affected
    bottom=False,      # ticks along the bottom edge are off
    top=False,        # ticks along the top edge are off
    left=False,       # ticks along the left edge are off
    right=False,         # ticks along the right edge are off
    labelbottom=False,
    labelleft=False)  
img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))
imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)
plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)
plt.title('SSIM map: Target vs. Prediction',fontsize=15)
plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)


#Root Squared Error between GT and Prediction
plt.subplot(2,3,6)
#plt.axis('off')
plt.tick_params(
    axis='both',      # changes apply to the x-axis and y-axis
    which='both',      # both major and minor ticks are affected
    bottom=False,      # ticks along the bottom edge are off
    top=False,        # ticks along the top edge are off
    left=False,       # ticks along the left edge are off
    right=False,         # ticks along the right edge are off
    labelbottom=False,
    labelleft=False) 
img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))
imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)
plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)
plt.title('RSE map: Target vs. Prediction',fontsize=15)
plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)

print('-----------------------------------')
print('Here are the average scores for the stacks you tested in Quality Control. To see the values for all slices, open the .csv file saved in the Quality Control folder.')
pdResults
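<font size = 4>`norm_minmse` in the cell above rescales the prediction so that it is directly comparable with the normalised ground truth: after subtracting the means, the factor that minimises the squared error between `scale * x` and `gt` is the least-squares solution, the covariance of `x` and `gt` divided by the variance of `x`. A minimal NumPy sketch of just this step (hypothetical name `affine_scale`; the percentile normalisation is omitted) shows that an affinely distorted copy of an image is mapped back onto the centred original:

```python
import numpy as np

def affine_scale(gt, x):
    """Centre both images and scale x by the least-squares factor.

    Minimises sum((gt_c - s * x_c)**2) over s, where *_c are mean-centred:
    s = <x_c, gt_c> / <x_c, x_c>, i.e. cov(x, gt) / var(x).
    """
    x_c = np.asarray(x, dtype=np.float64).ravel() - np.mean(x)
    gt_c = np.asarray(gt, dtype=np.float64).ravel() - np.mean(gt)
    s = np.dot(x_c, gt_c) / np.dot(x_c, x_c)
    shape = np.shape(gt)
    return gt_c.reshape(shape), (s * x_c).reshape(shape), s
```

<font size = 4>Because the offset is removed by centring and the gain by the scale factor, a prediction that differs from the target only by brightness and contrast is not penalised by the SSIM, NRMSE and PSNR computed afterwards.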


#**6. Using the trained model**
---

<font size = 4>In this section, unseen data is processed using the model trained in section 4. First, your unseen images are uploaded and prepared for prediction. Your trained model from section 4 is then applied to them, and the predictions are saved into your Google Drive.

## **6.1. Generate prediction(s) from unseen dataset**
---

<font size = 4>The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.

<font size = 4>**`Data_folder`:** This folder should contain the images that you want to process with your trained network.

<font size = 4>**`Results_folder`:** This folder will contain the predicted output images.

<font size = 4>**`Predictions_name`:** Enter the name under which your images will be stored. This will be added to the file name of the predictions.

<font size = 4>If you want to use a model different from the most recently trained one, untick the box and enter the name of the model in **`Prediction_model_name`** and its location in **`Prediction_model_path`** to use it for prediction. 

**Note: `Prediction_model_name` expects a folder name which contains a model.p file from a previous training.**
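In other words, the file that must exist is `Prediction_model_path/Prediction_model_name/model.p`. A minimal sketch of that check, assuming only the `model.p` file name given in the note. `has_fnet_model` is a hypothetical helper; the notebook itself only tests whether the folder exists:

```python
import os

def has_fnet_model(model_path, model_name):
    """True if model_path/model_name is a folder containing model.p (hypothetical helper)."""
    folder = os.path.join(model_path, model_name)
    # A folder that exists but lacks model.p would still fail at prediction time
    return os.path.isfile(os.path.join(folder, 'model.p'))
```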


In [0]:
#Before prediction, we remove the old results folder because fnet will not run if a path with the same name already exists.
#This is just in case you have already run predictions on a dataset with the same name.
#Your predictions are saved outside of the pytorch_fnet folder (in Results_folder), so they won't be lost when you run this section again.

import shutil  # some of these may already be imported at install; re-importing them is harmless
import csv
from tempfile import mkstemp
from shutil import move, copymode
from os import fdopen, remove
import os

if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):
  shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')

!pip install -U scipy==1.2.0
!pip install --no-cache-dir tifffile==2019.7.26 
from distutils.dir_util import copy_tree

def replace(file_path, pattern, subst):
    #Create temp file
    fh, abs_path = mkstemp()
    with fdopen(fh,'w') as new_file:
        with open(file_path) as old_file:
            for line in old_file:
                new_file.write(line.replace(pattern, subst))
    #Copy the file permissions from the old file to the new file
    copymode(file_path, abs_path)
    #Remove original file
    remove(file_path)
    #Move new file
    move(abs_path, file_path)

#Here we replace values in the old files
#Change maximum pixel number
replace("/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py",'n_max_pixels=9732096','n_max_pixels=20000000')
replace("/content/gdrive/My Drive/pytorch_fnet/predict.py",'6000000','20000000')

#Prevent resizing in the training and the prediction
replace("/content/gdrive/My Drive/pytorch_fnet/predict.py","0.37241","1.0")
replace("/content/gdrive/My Drive/pytorch_fnet/train_model.py","0.37241","1.0")



#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.

Data_folder = "" #@param {type:"string"}
Results_folder = "" #@param {type:"string"}


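# Predictions_name = "blabla" #@param {type:"string"}
# Fixed name of the temporary prediction folders created inside pytorch_fnet
Predictions_name = 'TempPredictionFolder'
# The trailing "}" closes the ${1:-...} default-value expansion that the sed command below rewrites in predict.sh
Predictions_name_x = Predictions_name+"}"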


#If the folder you are creating already exists, delete the existing version to overwrite.
if os.path.exists(Results_folder+'/'+Predictions_name):
  shutil.rmtree(Results_folder+'/'+Predictions_name)

#@markdown ###Do you want to use the current trained model?

Use_the_current_trained_model = False #@param{type:"boolean"}

#@markdown ###If not, provide the name of the model you want to use 

Prediction_model_name = "" #@param {type:"string"}
Prediction_model_path = "" #@param {type:"string"}

if not Use_the_current_trained_model:
  # Only check the model path when a previously trained model was requested
  full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'
  if os.path.exists(full_Prediction_model_path):
    print("The "+Prediction_model_name+" network will be used.")
  else:
    W  = '\033[0m'  # white (normal)
    R  = '\033[31m' # red
    print(R+'!! WARNING: The chosen model does not exist !!'+W)
    print('Please make sure you provide a valid model path and model name before proceeding further.')


if Use_the_current_trained_model:
  #Copy the contents of the saved_models folder from your training to the new folder
  #Here, we use distutils' copy_tree because we only need the contents of the trained model folder and it copies into an existing destination
  copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+dataset,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)
else:
  copy_tree(Prediction_model_path+'/'+Prediction_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)
  dataset = Prediction_model_name

# Get the name of the folder the test data is in
test_dataset_name = os.path.basename(os.path.normpath(Data_folder))

# Make predict.sh executable and set its default dataset name to Predictions_name.
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')
!chmod u+x /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh
!sed -i "s/1:-.*/1:-$Predictions_name_x/g" /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh

#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.
!sed -i "s/in test.*/in test/g" /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh

#Check that we are using .tif files (only the first file in Data_folder is inspected)
file_list = os.listdir(Data_folder)
text = file_list[0]

if text.endswith('.tif') or text.endswith('.tiff'):
  !chmod u+x /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh
  !if ! grep class_dataset /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\/DIR} \\\'$''\n'     --class_dataset TiffDataset \\/' /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh; fi
  !if grep CziDataset /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh; fi   

#Create test_data folder in pytorch_fnet

# If your test data is not in the pytorch_fnet data folder it needs to be copied there.
if Use_the_current_trained_model == True:
  if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name):
    shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name)
else:
  if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name):
    shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name)


# Make a folder that will hold the test.csv file in your new folder
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):
  os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)


os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')

#Make a new folder in saved_models to use the trained model for inference.
if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):
  os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) 


#Get file list from the folders containing the files you want to use for inference.
#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name)
test_signal = os.listdir(Data_folder)

#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.
os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')

#If an old test csv exists we want to overwrite it, so we can insert new test data.
if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):
  os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')

#Here we create a new test.csv
#Unseen data has no ground truth, so path_target simply repeats the signal path as a placeholder
with open('test.csv', 'w', newline='') as file:
  writer = csv.writer(file)
  writer.writerow(["path_signal","path_target"])
  for i in range(0,len(test_signal)):
    if Use_the_current_trained_model:
      writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+test_dataset_name+"/"+test_signal[i],"/content/gdrive/My Drive/pytorch_fnet/data/"+dataset+"/"+test_dataset_name+"/"+test_signal[i]])
    else:
      writer.writerow(["/content/gdrive/My Drive/pytorch_fnet/data/"+Predictions_name+"/"+test_dataset_name+"/"+test_signal[i],"/content/gdrive/My Drive/pytorch_fnet/data/"+Predictions_name+"/"+test_dataset_name+"/"+test_signal[i]])

#We run the predictions
os.chdir('/content/gdrive/My Drive/pytorch_fnet/')
!/content/gdrive/My\ Drive/pytorch_fnet/scripts/predict.sh $Predictions_name 0

#Save the results
results_path = '/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test'
os.makedirs(Results_folder, exist_ok=True)
#Keep only the per-image result folders (skipping the files fnet writes next to them) and sort them numerically so they follow the row order of test.csv
results_files = sorted((d for d in os.listdir(results_path) if os.path.isdir(os.path.join(results_path, d))), key=lambda d: (len(d), d))
for i in range(len(results_files)):
  #pred_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_folder+'/test/'+results_files[i])
  # shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/prediction_'+Predictions_name+'.tiff', Results_folder+'/'+'Prediction_'+Predictions_name+'_'+str(i)+'.tiff')
  # shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/3d/'+Predictions_name+'/test/'+results_files[i]+'/signal.tiff', Results_folder+'/'+'signal'+'_'+str(i)+'.tiff')
  shutil.copyfile(results_path+'/'+results_files[i]+'/prediction_'+Predictions_name+'.tiff', Results_folder+'/'+'Prediction_'+test_signal[i])
  shutil.copyfile(results_path+'/'+results_files[i]+'/signal.tiff', Results_folder+'/'+test_signal[i])

#Comment this line out if you want to keep the full original prediction results in the pytorch_fnet folder.
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')
# shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)

##**6.2. Assess predicted output**
---
<font size = 4>Here, we inspect an example prediction: the first image of your unseen dataset is shown next to its prediction. Select the slice you want to visualize.
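The viewer indexes both stacks as `image[slice_number, :, :]`, so it assumes Z is the first axis once the leading singleton axis has been squeezed away, and an out-of-range `slice_number` raises an `IndexError`. A small sketch that clamps the slice number to the stack and defaults to the middle plane. `pick_slice` is a hypothetical helper, and the usage test relies only on a synthetic NumPy array:

```python
import numpy as np

def pick_slice(stack, slice_number=None):
    """Return one 2D plane from a (Z, Y, X) stack (hypothetical helper).

    If slice_number is None, the middle plane is used; otherwise it is
    clamped to the valid range [0, Z - 1].
    """
    stack = np.asarray(stack)
    if stack.ndim != 3:
        raise ValueError('Expected a (Z, Y, X) stack, got shape %s' % (stack.shape,))
    n_z = stack.shape[0]
    if slice_number is None:
        slice_number = n_z // 2  # middle plane by default
    slice_number = int(np.clip(slice_number, 0, n_z - 1))
    return stack[slice_number, :, :]
```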

In [0]:
!pip install matplotlib==2.2.3
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
import os

##@markdown ###Which image would you like to view?
#image = "0" #@param {type:"string"}
os.chdir(Results_folder)

#@markdown ###Which slice would you like to view?
slice_number = 0 #@param {type:"number"}
source_image = io.imread(test_signal[0])
source_image = np.squeeze(source_image, axis=(0,))
prediction_image = io.imread('Prediction_'+test_signal[0])
prediction_image = np.squeeze(prediction_image, axis=(0,))

#Create the figure
fig = plt.figure()
fig.set_figheight(10)
fig.set_figwidth(20)
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

ax1.title.set_text('Source')
ax2.title.set_text('Prediction')

#Setting up colours
cmap = plt.cm.Greys

fig1 = ax1.imshow(source_image[slice_number,:,:], cmap = cmap, aspect = 'equal')
fig1.axes.get_xaxis().set_visible(True)
fig1.axes.get_yaxis().set_visible(True)

fig2 = ax2.imshow(prediction_image[slice_number,:,:], aspect = 'equal')
fig2.axes.get_xaxis().set_visible(True)
fig2.axes.get_yaxis().set_visible(True)

## **6.3. Download your predictions**
---

<font size = 4>**Store your data** and ALL its results elsewhere by downloading them from Google Drive, then clean the original folder tree (datasets, results, trained model, etc.) if you plan to train or use new networks. Otherwise, the notebook will **OVERWRITE** any files that have the same name.
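Downloading is easier when the results are bundled into a single file first. A minimal sketch using the standard-library `shutil.make_archive`; `archive_results` is a hypothetical helper, and the archive name and location are your choice:

```python
import os
import shutil

def archive_results(results_folder, archive_base):
    """Zip the contents of results_folder into archive_base + '.zip' and return its path."""
    # make_archive(base_name, format, root_dir) archives the files under root_dir
    # with paths relative to it, and returns the name of the archive it created
    return shutil.make_archive(archive_base, 'zip', root_dir=results_folder)
```

For example, `archive_results(Results_folder, '/content/fnet_predictions')` would write `/content/fnet_predictions.zip`, which you can then download as a single file.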

## **6.4. Purge unnecessary folders**
---


In [0]:
#@markdown ##If you have checked that all your data is saved, you can delete the pytorch_fnet folder from your drive by playing this cell.

import shutil
shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet')

#**Thank you for using fnet!**