# **Deep-STORM (2D)**

---

<font size = 4>Deep-STORM is a neural network that reconstructs super-resolution images from high-density single-molecule localization microscopy (SMLM) data, first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net-based network without skip connections. The network reconstructs 2D super-resolution images and is trained in a supervised manner on simulated high-density SMLM data for which the ground truth is available. These simulations are generated from random distributions of single molecules in a field of view and therefore do not imprint structural priors during training. The network outputs a super-resolution image with increased pixel density (typically an upsampling factor of 8 in each dimension).
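
<font size = 4>As a rough illustration of what an upsampling factor of 8 means for the pixel grid, the sketch below upsamples a small dummy frame by nearest-neighbour replication with `np.kron`, which is how the prediction function of this notebook prepares frames before passing them to the network. The dummy frame and the 100 nm camera pixel size are illustrative assumptions.

```python
import numpy as np

upsampling_factor = 8   # typical value quoted above
pixel_size = 100.0      # assumed camera pixel size in nm (illustrative)

# Dummy 3x4 low-resolution frame (illustrative data only)
frame = np.arange(12, dtype=np.float32).reshape(3, 4)

# Nearest-neighbour upsampling: each pixel becomes an 8x8 block of identical values
block = np.ones((upsampling_factor, upsampling_factor), dtype=np.float32)
frame_upsampled = np.kron(frame, block)

# Each output pixel covers a smaller physical area
pixel_size_hr = pixel_size / upsampling_factor

print(frame_upsampled.shape)  # (24, 32)
print(pixel_size_hr)          # 12.5
```

<font size = 4>With these assumed values, each 100 nm camera pixel maps onto an 8x8 block of 12.5 nm pixels in the reconstructed image.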

Deep-STORM has **two key advantages**:
- SMLM reconstruction at high density of emitters
- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.


---

<font size = 4>*Disclaimer*:

<font size = 4>This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (https://github.com/HenriquesLab/DeepLearning_Collab/wiki). It is jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.

<font size = 4>This notebook is based on the following paper:

<font size = 4>**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)

<font size = 4>The source code is available at: https://github.com/EliasNehme/Deep-STORM


<font size = 4>**Please also cite this original paper when using or developing this notebook.**

# **How to use this notebook?**

---

<font size = 4>Videos describing how to use our notebooks are available on YouTube:
  - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run-through of the workflow to obtain the notebooks and the provided test datasets, as well as a common use of the notebook
  - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook


---
### **Structure of a notebook**

<font size = 4>The notebook contains two types of cell:  

<font size = 4>**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.

<font size = 4>**Code cells** contain code, which can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears) and click it. Once execution is done, the animation of the play button stops. You can create a new code cell by clicking `+ Code`.

---
### **Table of contents, Code snippets** and **Files**

<font size = 4>On the top left side of the notebook you will find three tabs which contain, from top to bottom:

<font size = 4>*Table of contents* = contains the structure of the notebook. Click an entry to move quickly between sections.

<font size = 4>*Code snippets* = contains examples of how to code certain tasks. You can ignore this tab when using this notebook.

<font size = 4>*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here.

<font size = 4>**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.

<font size = 4>**Note:** The "sample data" in "Files" contains default files. Do not upload anything in here!

---
### **Making changes to the notebook**

<font size = 4>**You can make a copy** of the notebook and save it to your Google Drive. To do this, click *File* -> *Save a copy in Drive*.

<font size = 4>To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).
You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment.
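
<font size = 4>For example, a code cell edited this way could look like the sketch below. The variable name and values are purely illustrative and are not taken from a specific cell of this notebook.

```python
# batch_size = 16   # original value, kept as a comment for reference
batch_size = 8      # edited value that is actually used when the cell runs

print(batch_size)   # 8
```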

# **0. Before getting started**
---
<font size = 4> Deep-STORM can be trained on simulated SMLM datasets (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that generates a training dataset (section 3.1.b). A few parameters allow you to match the simulation to your experimental data. As described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).
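
<font size = 4>To give a feel for what such a simulation involves, here is a minimal sketch of one simulated frame and its ground truth. It is not the simulator of section 3.1.b: the function `simulate_frame`, its parameter names and its default values are hypothetical, and the optics are reduced to a Gaussian blur plus Poisson noise.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

def simulate_frame(n_emitters=20, fov_px=32, pixel_size=100.0, sigma_nm=110.0,
                   photons=1000.0, upsampling_factor=8, seed=0):
    """Hypothetical helper: one noisy low-resolution frame plus its ground truth."""
    rng = np.random.default_rng(seed)
    fov_nm = fov_px * pixel_size

    # Uniformly random emitter positions in nm, so no structural prior is imprinted
    x_nm = rng.uniform(0, fov_nm, n_emitters)
    y_nm = rng.uniform(0, fov_nm, n_emitters)

    # Ground truth: emitter counts on a grid upsampling_factor times finer than the camera
    hr_px = fov_px * upsampling_factor
    edges = np.linspace(0, fov_nm, hr_px + 1)
    ground_truth, _, _ = np.histogram2d(y_nm, x_nm, bins=(edges, edges))

    # Blur the fine grid with a Gaussian PSF (sigma converted to fine-grid pixels)
    sigma_hr_px = sigma_nm / (pixel_size / upsampling_factor)
    blurred = gaussian_filter(ground_truth * photons, sigma=sigma_hr_px, mode='constant')

    # Bin the fine grid down to camera pixels, then add photon shot noise
    frame = blurred.reshape(fov_px, upsampling_factor, fov_px, upsampling_factor).sum(axis=(1, 3))
    frame = rng.poisson(frame).astype(np.float32)

    return frame, ground_truth.astype(np.float32)

frame, ground_truth = simulate_frame()
print(frame.shape, ground_truth.shape)  # (32, 32) (256, 256)
```

<font size = 4>In this sketch, the emitter density, photon count and PSF width are the kind of parameters one would tune to resemble the experimental data.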

---
<font size = 4>**Important note**

<font size = 4>- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.

<font size = 4>- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.

<font size = 4>- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.

---

# **1. Install Deep-STORM and dependencies**
---


In [None]:
Notebook_version = '1.13.3'
Network = 'Deep-STORM'



from builtins import any as b_any

def get_requirements_path():
    # Store requirements file in 'contents' directory
    current_dir = os.getcwd()
    dir_count = current_dir.count('/') - 1
    path = '../' * (dir_count) + 'requirements.txt'
    return path

def filter_files(file_list, filter_list):
    filtered_list = []
    for fname in file_list:
        if b_any(fname.split('==')[0] in s for s in filter_list):
            filtered_list.append(fname)
    return filtered_list

def build_requirements_file(before, after):
    path = get_requirements_path()

    # Exporting requirements.txt for local run
    !pip freeze > $path

    # Get minimum requirements file
    df = pd.read_csv(path)
    mod_list = [m.split('.')[0] for m in after if not m in before]
    req_list_temp = df.values.tolist()
    req_list = [x[0] for x in req_list_temp]

    # Replace with package name and handle cases where import name is different to module name
    mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]
    mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]
    filtered_list = filter_files(req_list, mod_replace_list)

    # Write one requirement per line
    with open(path, 'w') as file:
        for item in filtered_list:
            file.write(item + '\n')

import sys
before = [str(m) for m in sys.modules]

#@markdown ##Install Deep-STORM and dependencies
# %% Model definition + helper functions

!pip install fpdf2
# Import keras modules and libraries
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization, Layer
from tensorflow.keras.callbacks import Callback
from tensorflow.keras import backend as K
from tensorflow.keras import optimizers, losses

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ReduceLROnPlateau
from skimage.transform import warp
from skimage.transform import SimilarityTransform
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as psnr
from scipy.signal import fftconvolve

# Import common libraries
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py
import scipy.io as sio
from os.path import abspath
from sklearn.model_selection import train_test_split
from skimage import io
import time
import os
import shutil
import csv
from PIL import Image
from PIL.TiffTags import TAGS
from scipy.ndimage import gaussian_filter
import math
from astropy.visualization import simple_norm
from sys import getsizeof
from fpdf import FPDF, HTMLMixin
from fpdf.enums import XPos, YPos
from pip._internal.operations.freeze import freeze
import subprocess
from datetime import datetime

# For sliders and dropdown menu, progress bar
from ipywidgets import interact
import ipywidgets as widgets
from tqdm import tqdm

# For Multi-threading in simulation
from numba import njit, prange

#Create a variable to get and store relative base path
base_path = os.getcwd()

# define a function that projects and rescales an image to the range [0,1]
def project_01(im):
    im = np.squeeze(im)
    min_val = im.min()
    max_val = im.max()
    return (im - min_val)/(max_val - min_val)

# normalize image given mean and std
def normalize_im(im, dmean, dstd):
    im = np.squeeze(im)
    return (im - dmean)/dstd

# Define the loss history recorder
class LossHistory(Callback):
    def on_train_begin(self, logs=None):
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        self.losses.append((logs or {}).get('loss'))

#  Define a matlab like gaussian 2D filter
def matlab_style_gauss2D(shape=(7,7),sigma=1):
    """
    2D gaussian filter - should give the same result as:
    MATLAB's fspecial('gaussian',[shape],[sigma])
    """
    m,n = [(ss-1.)/2. for ss in shape]
    y,x = np.ogrid[-m:m+1,-n:n+1]
    h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
    # Thresholding below uses the array's native float64 precision; the cast to float32 happens at the end
    h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
    sumh = h.sum()
    if sumh != 0:
        h /= sumh
    h = h*2.0
    h = h.astype('float32')
    return h

# Expand the filter dimensions
psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)
gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])

# Combined MSE + L1 loss
def L1L2loss(input_shape):
    def bump_mse(heatmap_true, spikes_pred):

        # generate the heatmap corresponding to the predicted spikes
        heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')

        # heatmaps MSE
        loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)

        # l1 on the predicted spikes
        loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))
        return loss_heatmaps + loss_spikes
    return bump_mse

# Define the concatenated conv2, batch normalization, and relu block
def conv_bn_relu(nb_filter, rk, ck, name):
    def f(input):
        conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\
                               padding="same", use_bias=False,\
                               kernel_initializer="Orthogonal",name='conv-'+name)(input)
        conv_norm = BatchNormalization(name='BN-'+name)(conv)
        conv_norm_relu = Activation(activation = "relu",name='Relu-'+name)(conv_norm)
        return conv_norm_relu
    return f

# Define the model architecture
def CNN(input,names):
    Features1 = conv_bn_relu(32,3,3,names+'F1')(input)
    pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)
    Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)
    Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)
    Features4 = conv_bn_relu(512,3,3,names+'F4')(pool3)
    up5 = UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)
    Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)
    up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)
    Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)
    up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)
    Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)
    return Features7

# Define the Model building for an arbitrary input size
def buildModel(input_dim, initial_learning_rate = 0.001):
    input_ = Input (shape = (input_dim))
    act_ = CNN (input_,'CNN')
    density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding="same",\
                                  activation="linear", use_bias = False,\
                                  kernel_initializer="Orthogonal",name='Prediction')(act_)
    model = Model (inputs= input_, outputs=density_pred)
    opt = optimizers.Adam(learning_rate = initial_learning_rate)
    model.compile(optimizer=opt, loss = L1L2loss(input_dim))
    return model


# define a function that trains a model for a given data SNR and density
def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size, upsampling_factor=8, validation_split = 0.3, initial_learning_rate = 0.001, pretrained_model_path = '', L2_weighting_factor = 100):

    """
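    This function trains a CNN model on the desired training set, given the
    upsampled training image patches and their matching ground-truth heatmaps.

    # Inputs
    patches               - upsampled training image patches (one 2D patch per example)
    heatmaps              - ground-truth heatmaps, same patch size as the patches
    modelPath             - folder where the weights, metadata and quality control files are saved
    epochs                - number of training epochs
    steps_per_epoch       - number of batches drawn per epoch
    batch_size            - number of patches per batch
    upsampling_factor     - upsampling factor, stored in the model metadata
    validation_split      - fraction of the examples held out for validation
    initial_learning_rate - initial learning rate of the Adam optimizer
    pretrained_model_path - weights file to start from (empty string for random initial weights)
    L2_weighting_factor   - stored in the model metadata as "Normalization factor"

    # Outputs
    The function saves the weights of the trained model to hdf5 files and the
    normalization factors to a mat file. These are loaded later by
    batchFramePredictionLocalization to run predictions.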
    """

    # for reproducibility
    np.random.seed(123)

    X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size = validation_split, random_state=42)
    print('Number of training examples: %d' % X_train.shape[0])
    print('Number of validation examples: %d' % X_test.shape[0])

    # Setting type
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    y_train = y_train.astype('float32')
    y_test = y_test.astype('float32')


    #===================== Training set normalization ==========================
    # normalize training images to be in the range [0,1] and calculate the
    # training set mean and std
    mean_train = np.zeros(X_train.shape[0],dtype=np.float32)
    std_train = np.zeros(X_train.shape[0], dtype=np.float32)
    for i in range(X_train.shape[0]):
        X_train[i, :, :] = project_01(X_train[i, :, :])
        mean_train[i] = X_train[i, :, :].mean()
        std_train[i] = X_train[i, :, :].std()

    # resulting normalized training images
    mean_val_train = mean_train.mean()
    std_val_train = std_train.mean()
    X_train_norm = np.zeros(X_train.shape, dtype=np.float32)
    for i in range(X_train.shape[0]):
        X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)

    # patch size
    psize = X_train_norm.shape[1]

    # Reshaping
    X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)

    # ===================== Test set normalization ==========================
    # normalize test images to be in the range [0,1] and calculate the test set
    # mean and std
    mean_test = np.zeros(X_test.shape[0],dtype=np.float32)
    std_test = np.zeros(X_test.shape[0], dtype=np.float32)
    for i in range(X_test.shape[0]):
        X_test[i, :, :] = project_01(X_test[i, :, :])
        mean_test[i] = X_test[i, :, :].mean()
        std_test[i] = X_test[i, :, :].std()

    # resulting normalized test images
    mean_val_test = mean_test.mean()
    std_val_test = std_test.mean()
    X_test_norm = np.zeros(X_test.shape, dtype=np.float32)
    for i in range(X_test.shape[0]):
        X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)

    # Reshaping
    X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)

    # Reshaping labels
    Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)
    Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)

    # Save datasets to a matfile to open later in matlab
    mdict = {"mean_test": mean_val_test, "std_test": std_val_test, "upsampling_factor": upsampling_factor, "Normalization factor": L2_weighting_factor}
    sio.savemat(os.path.join(modelPath,"model_metadata.mat"), mdict)


    # Set the dimensions ordering according to the tensorflow convention
    # K.set_image_dim_ordering('tf')
    K.set_image_data_format('channels_last')

    # Save the model weights after each epoch if the validation loss decreased
    checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,"weights_best.hdf5"), verbose=1,
                                   save_best_only=True)

    # Reduce the learning rate when the validation loss reaches a plateau
    change_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00005)

    # Model building and compilation
    model = buildModel((psize, psize, 1), initial_learning_rate = initial_learning_rate)
    model.summary()

    # Load pretrained model
    if not pretrained_model_path:
      print('Using random initial model weights.')
    else:
      print('Loading model weights from '+pretrained_model_path)
      model.load_weights(pretrained_model_path)

    # Create an image data generator for real time data augmentation
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=0.,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.,  # randomly shift images vertically (fraction of total height)
        zoom_range=0.,
        shear_range=0.,
        horizontal_flip=False,  # randomly flip images
        vertical_flip=False,  # randomly flip images
        fill_mode='constant',
        data_format=K.image_data_format())

    # Fit the image generator on the training data
    datagen.fit(X_train_norm)

    # loss history recorder
    history = LossHistory()

    # Inform user training begun
    print('-------------------------------')
    print('Training model...')

    # Fit model on the batches generated by datagen.flow()
    train_history = model.fit(datagen.flow(X_train_norm, Y_train, batch_size=batch_size),
                              steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1,
                              validation_data=(X_test_norm, Y_test),
                              callbacks=[history, checkpointer, change_lr])

    # Inform user training ended
    print('-------------------------------')
    print('Training Complete!')

    # Save the last model
    model.save(os.path.join(modelPath, 'weights_last.hdf5'))

    # convert the history.history dict to a pandas DataFrame:
    lossData = pd.DataFrame(train_history.history)

    if os.path.exists(os.path.join(modelPath,"Quality Control")):
      shutil.rmtree(os.path.join(modelPath,"Quality Control"))

    os.makedirs(os.path.join(modelPath,"Quality Control"))

    # The training evaluation.csv is saved (overwrites the Files if needed).
    lossDataCSVpath = os.path.join(modelPath,"Quality Control/training_evaluation.csv")
    with open(lossDataCSVpath, 'w') as f:
      writer = csv.writer(f)
      writer.writerow(['loss','val_loss','learning rate'])
      for i in range(len(train_history.history['loss'])):
        writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i], train_history.history['lr'][i]])

    return


# Normalization functions from Martin Weigert used in CARE
def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):
    """Percentile-based image normalization (adapted from Martin Weigert)."""

    mi = np.percentile(x,pmin,axis=axis,keepdims=True)
    ma = np.percentile(x,pmax,axis=axis,keepdims=True)
    return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)


def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):
    """This function is adapted from Martin Weigert"""
    if dtype is not None:
        x   = x.astype(dtype,copy=False)
        mi  = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)
        ma  = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)
        eps = dtype(eps)

    try:
        import numexpr
        x = numexpr.evaluate("(x - mi) / ( ma - mi + eps )")
    except ImportError:
        x =                   (x - mi) / ( ma - mi + eps )

    if clip:
        x = np.clip(x,0,1)

    return x

def norm_minmse(gt, x, normalize_gt=True):
    """
    Normalizes and affinely scales an image pair such that the MSE is minimized.
    This function is adapted from Martin Weigert.

    Parameters
    ----------
    gt: ndarray
        the ground truth image
    x: ndarray
        the image that will be affinely scaled
    normalize_gt: bool
        set to True if the gt image should be normalized (default)

    Returns
    -------
    gt_scaled, x_scaled
    """
    if normalize_gt:
        gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)
    x = x.astype(np.float32, copy=False) - np.mean(x)
    #x = x - np.mean(x)
    gt = gt.astype(np.float32, copy=False) - np.mean(gt)
    #gt = gt - np.mean(gt)
    scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())
    return gt, scale * x


# Multi-threaded Erf-based image construction
@njit(parallel=True)
def FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = (64,64), pixel_size = 100):
  w = image_size[0]
  h = image_size[1]
  erfImage = np.zeros((w, h))
  for ij in prange(w*h):
    # Row-major pixel index matching the allocated array: image_size[0] rows, image_size[1] columns
    j = ij // h
    i = ij - j*h
    for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):
      # Don't bother if the emitter has photons <= 0 or if Sigma <= 0
      if (sigma > 0) and (photon > 0):
        S = sigma*math.sqrt(2)
        x = i*pixel_size - xc
        y = j*pixel_size - yc
        # Don't bother if the emitter is further than 4 sigma from the centre of the pixel
        if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:
          ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)
          ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)
          erfImage[j][i] += 0.25*photon*ErfX*ErfY
  return erfImage


# Sequential loop: several localizations can land in the same pixel, so a parallel += would race
@njit
def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):
  w = image_size[0]
  h = image_size[1]
  locImage = np.zeros((image_size[0],image_size[1]) )
  n_locs = len(xc_array)

  for e in range(n_locs):
    locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1

  return locImage



def getPixelSizeTIFFmetadata(TIFFpath, display=False):
  with Image.open(TIFFpath) as img:
    meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}


  # TIFF tags
  # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml
  # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html
  ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution
  width = meta_dict['ImageWidth'][0]
  height = meta_dict['ImageLength'][0]

  xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit

  if len(xResolution) == 1:
    xResolution = xResolution[0]
  elif len(xResolution) == 2:
    xResolution = xResolution[0]/xResolution[1]
  else:
    print('Image resolution not defined.')
    xResolution = 1

  if ResolutionUnit == 2:
    # Units given are in inches
    pixel_size = 0.0254*1e9/xResolution
  elif ResolutionUnit == 3:
    # Units given are in cm
    pixel_size = 0.01*1e9/xResolution
  else:
    # ResolutionUnit is therefore 1
    print('Resolution unit not defined. Assuming: um')
    pixel_size = 1e3/xResolution

  if display:
    print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')
    print('Image size: '+str(width)+'x'+str(height))

  return (pixel_size, width, height)


def saveAsTIF(path, filename, array, pixel_size):
  """
  Image saving using PIL to save as .tif format
  # Input
  path       - path where it will be saved
  filename   - name of the file to save (no extension)
  array      - numpy array containing the data in the required format
  pixel_size - physical size of pixels in nanometers (identical for x and y)
  """

  # print('Data type: '+str(array.dtype))
  if (array.dtype == np.uint16):
    mode = 'I;16'
  elif (array.dtype == np.uint32):
    mode = 'I'
  else:
    mode = 'F'

  # Rounding the pixel size to the nearest number that divides exactly 1cm.
  # Resolution needs to be a rational number --> see TIFF format
  # pixel_size = 10000/(round(10000/pixel_size))

  if len(array.shape) == 2:
    im = Image.fromarray(array)
    im.save(os.path.join(path, filename+'.tif'),
                  mode = mode,
                  resolution_unit = 3,
                  resolution = 0.01*1e9/pixel_size)


  elif len(array.shape) == 3:
    imlist = []
    for frame in array:
      imlist.append(Image.fromarray(frame))

    imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,
                  append_images=imlist[1:],
                  mode = mode,
                  resolution_unit = 3,
                  resolution = 0.01*1e9/pixel_size)

  return




class Maximafinder(Layer):
    def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):
        super(Maximafinder, self).__init__(**kwargs)
        self.thresh = tf.constant(thresh, dtype=tf.float32)
        self.nhood = neighborhood_size
        self.use_local_avg = use_local_avg

    def build(self, input_shape):
        if self.use_local_avg is True:
          self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])
          self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])
          self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])

    def call(self, inputs):

        # local maxima positions
        max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)
        cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)
        indices = tf.where(cond)
        bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]
        confidence = tf.gather_nd(inputs, indices)

        # local CoG estimator
        if self.use_local_avg:
          x_image = K.conv2d(inputs, self.kernel_x, padding='same')
          y_image = K.conv2d(inputs, self.kernel_y, padding='same')
          sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')
          confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)
          x_local = tf.math.divide(tf.gather_nd(x_image, indices),tf.gather_nd(sum_image, indices))
          y_local = tf.math.divide(tf.gather_nd(y_image, indices),tf.gather_nd(sum_image, indices))
          xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, dtype=tf.float32)
          yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)
        else:
          xind = tf.cast(xind, dtype=tf.float32)
          yind = tf.cast(yind, dtype=tf.float32)

        return bind, xind, yind, confidence

    def get_config(self):

        # Implement get_config to enable serialization. This is optional.
        base_config = super(Maximafinder, self).get_config()
        config = {}
        return dict(list(base_config.items()) + list(config.items()))



# ------------------------------- Prediction with postprocessing  function-------------------------------
def batchFramePredictionLocalization(dataPath, filename, modelPath, savePath, batch_size=1, thresh=0.1, neighborhood_size=3, use_local_avg = False, pixel_size = None):
    """
    This function tests a trained model on the desired test set, given the
    tiff stack of test images, learned weights, and normalization factors.

    # Inputs
    dataPath          - the path to the folder containing the tiff stack(s) to run prediction on
    filename          - the name of the file to process
    modelPath         - the path to the folder containing the weights file and the mean and standard deviation file generated in train_model
    savePath          - the path to the folder where to save the prediction
    batch_size        - the number of frames to predict on for each iteration
    thresh            - threshold percentage from the maximum of the gaussian scaling
    neighborhood_size - the size of the neighborhood for local maxima finding
    use_local_avg     - Boolean whether to perform local averaging or not
    pixel_size        - pixel size of the input images in nm; read from the TIFF metadata if None
    """

    # load mean and std
    matfile = sio.loadmat(os.path.join(modelPath,'model_metadata.mat'))
    test_mean = np.array(matfile['mean_test'])
    test_std = np.array(matfile['std_test'])
    upsampling_factor = np.array(matfile['upsampling_factor'])
    upsampling_factor = int(upsampling_factor.item()) # convert to scalar
    L2_weighting_factor = np.array(matfile['Normalization factor'])
    L2_weighting_factor = L2_weighting_factor.item() # convert to scalar

    # Read in the raw file
    Images = io.imread(os.path.join(dataPath, filename))
    if pixel_size is None:
      pixel_size, _, _ = getPixelSizeTIFFmetadata(os.path.join(dataPath, filename), display=True)
    pixel_size_hr = pixel_size/upsampling_factor

    # get dataset dimensions
    (nFrames, M, N) = Images.shape
    print('Input image is '+str(N)+'x'+str(M)+' with '+str(nFrames)+' frames.')

    # Build the model for a bigger image
    model = buildModel((upsampling_factor*M, upsampling_factor*N, 1))

    # Load the trained weights
    model.load_weights(os.path.join(modelPath,'weights_best.hdf5'))

    # add a post-processing module
    max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, use_local_avg)

    # Initialise the results: lists will be used to collect all the localizations
    frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []

    # Initialise the results
    Prediction = np.zeros((M*upsampling_factor, N*upsampling_factor), dtype=np.float32)
    Widefield = np.zeros((M, N), dtype=np.float32)

    # run model in batches
    n_batches = math.ceil(nFrames/batch_size)
    for b in tqdm(range(n_batches)):

      nF = min(batch_size, nFrames - b*batch_size)
      Images_norm = np.zeros((nF, M, N),dtype=np.float32)
      Images_upsampled = np.zeros((nF, M*upsampling_factor, N*upsampling_factor), dtype=np.float32)

      # Upsampling using a simple nearest neighbor interp and calculating - MULTI-THREAD this?
      for f in range(nF):
        Images_norm[f,:,:] = project_01(Images[b*batch_size+f,:,:])
        Images_norm[f,:,:] = normalize_im(Images_norm[f,:,:], test_mean, test_std)
        Images_upsampled[f,:,:] = np.kron(Images_norm[f,:,:], np.ones((upsampling_factor,upsampling_factor)))
        Widefield += Images[b*batch_size+f,:,:]

      # Reshaping
      Images_upsampled = np.expand_dims(Images_upsampled,axis=3)

      # Run prediction and local maxima finding
      predicted_density = model.predict_on_batch(Images_upsampled)
      predicted_density[predicted_density < 0] = 0
      Prediction += predicted_density.sum(axis = 3).sum(axis = 0)

      bind, xind, yind, confidence = max_layer(predicted_density)

      # normalizing the confidence by the L2_weighting_factor
      confidence /= L2_weighting_factor

      # turn indices to nms and append to the results
      xind, yind = xind*pixel_size_hr, yind*pixel_size_hr
      frmind = (bind.numpy() + b*batch_size + 1).tolist()
      xind = xind.numpy().tolist()
      yind = yind.numpy().tolist()
      confidence = confidence.numpy().tolist()
      frame_number_list += frmind
      x_nm_list += xind
      y_nm_list += yind
      confidence_au_list += confidence

    # Open and create the csv file that will contain all the localizations
    if use_local_avg:
      ext = '_avg'
    else:
      ext = '_max'
    with open(os.path.join(savePath, 'Localizations_' + os.path.splitext(filename)[0] + ext + '.csv'), "w", newline='') as file:
      writer = csv.writer(file)
      writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])
      locs = list(zip(frame_number_list, x_nm_list, y_nm_list, confidence_au_list))
      writer.writerows(locs)

    # Save the prediction and widefield image
    Widefield = np.kron(Widefield, np.ones((upsampling_factor,upsampling_factor)))
    Widefield = np.float32(Widefield)

    # io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'), Prediction)
    # io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'), Widefield)

    saveAsTIF(savePath, 'Predicted_'+os.path.splitext(filename)[0], Prediction, pixel_size_hr)
    saveAsTIF(savePath, 'Widefield_'+os.path.splitext(filename)[0], Widefield, pixel_size_hr)


    return


# Colors for the warning messages
class bcolors:
  WARNING = '\033[31m'
  NORMAL = '\033[0m'  # white (normal)



def list_files(directory, extension):
  return (f for f in os.listdir(directory) if f.endswith('.' + extension))


# @njit(parallel=True)
def subPixelMaxLocalization(array, method = 'CoM', patch_size = 3):
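  xMaxInd, yMaxInd = np.unravel_index(array.argmax(), array.shape, order='C')
  # Square patch of side 2*patch_size+1 centred on the maximum
  # (assumes the maximum lies at least patch_size pixels away from the border)
  centralPatch = array[(xMaxInd-patch_size):(xMaxInd+patch_size+1),(yMaxInd-patch_size):(yMaxInd+patch_size+1)]

  if (method == 'MAX'):
    x0 = xMaxInd
    y0 = yMaxInd

  elif (method == 'CoM'):
    # Intensity-weighted centre of mass of the central patch
    x0 = 0
    y0 = 0
    S = 0
    patch_side = 2*patch_size + 1
    for xy in range(patch_side*patch_side):
      y = math.floor(xy/patch_side)
      x = xy - y*patch_side
      x0 += x*centralPatch[x,y]
      y0 += y*centralPatch[x,y]
      S += centralPatch[x,y]

    # Convert from patch coordinates back to array coordinates
    x0 = x0/S - patch_size + xMaxInd
    y0 = y0/S - patch_size + yMaxInd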

  elif (method == 'Radiality'):
    # Not implemented yet
    x0 = xMaxInd
    y0 = yMaxInd

  return (x0, y0)


@njit(parallel=True)
def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):
  n_locs = xc_array.shape[0]
  xc_array_Corr = np.empty(n_locs)
  yc_array_Corr = np.empty(n_locs)

  for loc in prange(n_locs):
    xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc]]
    yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc]]

  return (xc_array_Corr, yc_array_Corr)


print('--------------------------------')
print('DeepSTORM installation complete.')

# Check if this is the latest version of the notebook

All_notebook_versions = pd.read_csv("https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv", dtype=str)
print('Notebook version: '+Notebook_version)

Latest_Notebook_version = All_notebook_versions[All_notebook_versions["Notebook"] == Network]['Version'].iloc[0]
print('Latest notebook version: '+Latest_Notebook_version)


if Notebook_version == Latest_Notebook_version:
  print("This notebook is up-to-date.")
else:
  print(bcolors.WARNING +"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki" + bcolors.NORMAL)


# Latest_notebook_version = pd.read_csv("https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv")

# if Notebook_version == list(Latest_notebook_version.columns):
#   print("This notebook is up-to-date.")

# if not Notebook_version == list(Latest_notebook_version.columns):
#   print(bcolors.WARNING +"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki")


def pdf_export(trained = False, raw_data = False, pretrained_model = False):
  class MyFPDF(FPDF, HTMLMixin):
    pass

  pdf = MyFPDF()
  pdf.add_page()
  pdf.set_right_margin(-1)
  pdf.set_font("Arial", size = 11, style='B')


  #model_name = 'little_CARE_test'
  day = datetime.now()
  datetime_str = str(day)[0:10]

  Header = 'Training report for '+Network+' model ('+model_name+')\nDate: '+datetime_str
  pdf.multi_cell(180, 5, txt = Header, align = 'L')
  pdf.ln(1)

  # add another cell
  if trained:
    training_time = "Training time: "+str(hours)+ "hour(s) "+str(minutes)+"min(s) "+str(round(seconds))+"sec(s)"
    pdf.cell(190, 5, txt = training_time, new_x=XPos.LMARGIN, new_y=YPos.NEXT, align='L')
  pdf.ln(1)

  Header_2 = 'Information for your materials and method:'
  pdf.cell(190, 5, txt=Header_2,  new_x=XPos.LMARGIN, new_y=YPos.NEXT, align='L')

  all_packages = ''
  for requirement in freeze(local_only=True):
    all_packages = all_packages+requirement+', '
  #print(all_packages)

  #Main Packages
  main_packages = ''
  version_numbers = []
  for name in ['tensorflow','numpy','keras']:
    find_name=all_packages.find(name)
    main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '
    #Version numbers only here:
    version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])

  try:
    cuda_version = subprocess.run(["nvcc","--version"],stdout=subprocess.PIPE)
    cuda_version = cuda_version.stdout.decode('utf-8')
    cuda_version = cuda_version[cuda_version.find(', V')+3:-1]
  except:
    cuda_version = ' - No cuda found - '
  try:
    gpu_name = subprocess.run(["nvidia-smi"],stdout=subprocess.PIPE)
    gpu_name = gpu_name.stdout.decode('utf-8')
    gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]
  except:
    gpu_name = ' - No GPU found - '
  #print(cuda_version[cuda_version.find(', V')+3:-1])
  #print(gpu_name)
  if raw_data == True:
    shape = (M,N)
  else:
    shape = (int(FOV_size/pixel_size),int(FOV_size/pixel_size))
  #dataset_size = len(os.listdir(Training_source))

  text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+')) with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+' GPU.'

  if pretrained_model:
    text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+')) with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+' GPU.'

  pdf.set_font('')
  pdf.set_font_size(10.)
  pdf.multi_cell(180, 5, txt = text, align='L')
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font("Arial", size = 11, style='B')
  pdf.cell(190, 5, txt = 'Training dataset', align='L',  new_x=XPos.LMARGIN, new_y=YPos.NEXT)
  pdf.set_font('')
  pdf.set_font_size(10.)
  if raw_data==False:
    simul_text = 'The training dataset was created in the notebook using the following simulation settings:'
    pdf.cell(200, 5, txt=simul_text, align='L')
    pdf.ln(1)
    html = """
    <table width=60% style="margin-left:0px;">
      <tr>
        <th width = 50% align="left">Setting</th>
        <th width = 50% align="left">Simulated Value</th>
      </tr>
      <tr>
        <td width = 50%>FOV_size</td>
        <td width = 50%>{0}</td>
      </tr>
      <tr>
        <td width = 50%>pixel_size</td>
        <td width = 50%>{1}</td>
      </tr>
      <tr>
        <td width = 50%>ADC_per_photon_conversion</td>
        <td width = 50%>{2}</td>
      </tr>
      <tr>
        <td width = 50%>ReadOutNoise_ADC</td>
        <td width = 50%>{3}</td>
      </tr>
      <tr>
        <td width = 50%>ADC_offset</td>
        <td width = 50%>{4}</td>
      </tr>
      <tr>
        <td width = 50%>emitter_density</td>
        <td width = 50%>{5}</td>
      </tr>
      <tr>
        <td width = 50%>emitter_density_std</td>
        <td width = 50%>{6}</td>
      </tr>
      <tr>
        <td width = 50%>number_of_frames</td>
        <td width = 50%>{7}</td>
      </tr>
      <tr>
        <td width = 50%>sigma</td>
        <td width = 50%>{8}</td>
      </tr>
      <tr>
        <td width = 50%>sigma_std</td>
        <td width = 50%>{9}</td>
      </tr>
      <tr>
        <td width = 50%>n_photons</td>
        <td width = 50%>{10}</td>
      </tr>
      <tr>
        <td width = 50%>n_photons_std</td>
        <td width = 50%>{11}</td>
      </tr>
    </table>
    """.format(FOV_size, pixel_size, ADC_per_photon_conversion, ReadOutNoise_ADC, ADC_offset, emitter_density, emitter_density_std, number_of_frames, sigma, sigma_std, n_photons, n_photons_std)
    pdf.write_html(html)
  else:
    simul_text = 'The training dataset was simulated using ThunderSTORM and loaded into the notebook.'
    pdf.multi_cell(190, 5, txt=simul_text, align='L')
    pdf.ln(1)
    pdf.set_font("Arial", size = 11, style='B')
    #pdf.ln(1)
    #pdf.cell(190, 5, txt = 'Training Dataset', align='L',  new_x=XPos.LMARGIN, new_y=YPos.NEXT)
    pdf.set_font('')
    pdf.set_font('Arial', size = 10, style = 'B')
    pdf.cell(29, 5, txt= 'ImageData_path', align = 'L', new_x=XPos.RIGHT, new_y=YPos.TOP)
    pdf.set_font('')
    pdf.multi_cell(170, 5, txt = ImageData_path, align = 'L')
    pdf.ln(1)
    pdf.set_font('')
    pdf.set_font('Arial', size = 10, style = 'B')
    pdf.cell(28, 5, txt= 'LocalizationData_path:', align = 'L', new_x=XPos.RIGHT, new_y=YPos.TOP)
    pdf.set_font('')
    pdf.multi_cell(170, 5, txt = LocalizationData_path, align = 'L')
    pdf.ln(1)
    pdf.set_font('Arial', size = 10, style = 'B')
    pdf.cell(28, 5, txt= 'pixel_size:', align = 'L', new_x=XPos.RIGHT, new_y=YPos.TOP)
    pdf.set_font('')
    pdf.multi_cell(170, 5, txt = str(pixel_size), align = 'L')
    pdf.ln(1)
  #pdf.cell(190, 5, txt=aug_text, align='L',  new_x=XPos.LMARGIN, new_y=YPos.NEXT)
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.ln(1)
  pdf.cell(180, 5, txt = 'Parameters', align='L',  new_x=XPos.LMARGIN, new_y=YPos.NEXT)
  pdf.set_font('')
  pdf.set_font_size(10.)
  # if Use_Default_Advanced_Parameters:
  #   pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')
  pdf.cell(200, 5, txt='The following parameters were used to generate patches:')
  pdf.ln(1)
  html = """
  <table width=70% style="margin-left:0px;">
    <tr>
      <th width = 50% align="left">Training Parameter</th>
      <th width = 50% align="left">Value</th>
    </tr>
    <tr>
      <td width = 50%>patch_size</td>
      <td width = 50%>{0}</td>
    </tr>
    <tr>
      <td width = 50%>upsampling_factor</td>
      <td width = 50%>{1}</td>
    </tr>
    <tr>
      <td width = 50%>num_patches_per_frame</td>
      <td width = 50%>{2}</td>
    </tr>
    <tr>
      <td width = 50%>min_number_of_emitters_per_patch</td>
      <td width = 50%>{3}</td>
    </tr>
    <tr>
      <td width = 50%>max_num_patches</td>
      <td width = 50%>{4}</td>
    </tr>
    <tr>
      <td width = 50%>gaussian_sigma</td>
      <td width = 50%>{5}</td>
    </tr>
    <tr>
      <td width = 50%>Automatic_normalization</td>
      <td width = 50%>{6}</td>
    </tr>
    <tr>
      <td width = 50%>L2_weighting_factor</td>
      <td width = 50%>{7}</td>
    </tr>
  </table>
  """.format(str(patch_size)+'x'+str(patch_size), upsampling_factor, num_patches_per_frame, min_number_of_emitters_per_patch, max_num_patches, gaussian_sigma, Automatic_normalization, L2_weighting_factor)
  pdf.write_html(html)
  pdf.ln(3)
  pdf.set_font('Arial', size=10)
  pdf.cell(200, 5, txt='The following parameters were used for training:')
  pdf.ln(1)
  html = """
  <table width=70% style="margin-left:0px;">
    <tr>
      <th width = 50% align="left">Training Parameter</th>
      <th width = 50% align="left">Value</th>
    </tr>
    <tr>
      <td width = 50%>number_of_epochs</td>
      <td width = 50%>{0}</td>
    </tr>
    <tr>
      <td width = 50%>batch_size</td>
      <td width = 50%>{1}</td>
    </tr>
    <tr>
      <td width = 50%>number_of_steps</td>
      <td width = 50%>{2}</td>
    </tr>
    <tr>
      <td width = 50%>percentage_validation</td>
      <td width = 50%>{3}</td>
    </tr>
    <tr>
      <td width = 50%>initial_learning_rate</td>
      <td width = 50%>{4}</td>
    </tr>
  </table>
  """.format(number_of_epochs,batch_size,number_of_steps,percentage_validation,initial_learning_rate)
  pdf.write_html(html)

  pdf.ln(1)
  # pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.cell(21, 5, txt= 'Model Path:', align = 'L', new_x=XPos.RIGHT, new_y=YPos.TOP)
  pdf.set_font('')
  pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')
  pdf.ln(1)

  pdf.ln(1)
  pdf.cell(60, 5, txt = 'Example Training Images',  new_x=XPos.LMARGIN, new_y=YPos.NEXT)
  pdf.ln(1)
  exp_size = io.imread(base_path + '/TrainingDataExample_DeepSTORM2D.png').shape
  pdf.image(base_path + '/TrainingDataExample_DeepSTORM2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))
  pdf.ln(1)
  ref_1 = 'References:\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "Democratising deep learning for microscopy with ZeroCostDL4Mic." Nature Communications (2021).'
  pdf.multi_cell(190, 5, txt = ref_1, align='L')
  pdf.ln(1)
  ref_2 = '- Deep-STORM: Nehme, Elias, et al. "Deep-STORM: super-resolution single-molecule microscopy by deep learning." Optica 5.4 (2018): 458-464.'
  pdf.multi_cell(190, 5, txt = ref_2, align='L')
  pdf.ln(1)
  # if Use_Data_augmentation:
  #   ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. "Augmentor: an image augmentation library for machine learning." arXiv preprint arXiv:1708.04680 (2017).'
  #   pdf.multi_cell(190, 5, txt = ref_3, align='L')
  pdf.ln(3)
  reminder = 'Important:\nRemember to perform the quality control step on all newly trained models\nPlease consider depositing your training dataset on Zenodo'
  pdf.set_font('Arial', size = 11, style='B')
  pdf.multi_cell(190, 5, txt=reminder, align='C')
  pdf.ln(1)

  pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')
  print('------------------------------')
  print('PDF report exported in '+model_path+'/'+model_name+'/')

def qc_pdf_export():
  class MyFPDF(FPDF, HTMLMixin):
    pass

  pdf = MyFPDF()
  pdf.add_page()
  pdf.set_right_margin(-1)
  pdf.set_font("Arial", size = 11, style='B')

  Network = 'Deep-STORM'
  #model_name = os.path.basename(full_QC_model_path)
  day = datetime.now()
  datetime_str = str(day)[0:10]

  Header = 'Quality Control report for '+Network+' model ('+os.path.basename(QC_model_path)+')\nDate: '+datetime_str
  pdf.multi_cell(180, 5, txt = Header, align = 'L')
  pdf.ln(1)

  all_packages = ''
  for requirement in freeze(local_only=True):
    all_packages = all_packages+requirement+', '

  pdf.set_font('')
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.ln(2)
  pdf.cell(190, 5, txt = 'Loss curves',  new_x=XPos.LMARGIN, new_y=YPos.NEXT, align='L')
  pdf.ln(1)
  if os.path.exists(savePath+'/lossCurvePlots.png'):
    exp_size = io.imread(savePath+'/lossCurvePlots.png').shape
    pdf.image(savePath+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))
  else:
    pdf.set_font('')
    pdf.set_font('Arial', size=10)
    pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')
  pdf.ln(2)
  pdf.set_font('')
  pdf.set_font('Arial', size = 10, style = 'B')
  pdf.ln(3)
  pdf.cell(80, 5, txt = 'Example Quality Control Visualisation',  new_x=XPos.LMARGIN, new_y=YPos.NEXT)
  pdf.ln(1)
  exp_size = io.imread(savePath+'/QC_example_data.png').shape
  pdf.image(savePath+'/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))
  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font('Arial', size = 11, style = 'B')
  pdf.ln(1)
  pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L',  new_x=XPos.LMARGIN, new_y=YPos.NEXT)
  pdf.set_font('')
  pdf.set_font_size(10.)

  pdf.ln(1)
  html = """
  <body>
  <font size="7" face="Courier" >
  <table width=94% style="margin-left:0px;">"""
  with open(savePath+'/'+os.path.basename(QC_model_path)+'_QC_metrics.csv', 'r') as csvfile:
    metrics = csv.reader(csvfile)
    header = next(metrics)
    image = header[0]
    mSSIM_PvsGT = header[1]
    mSSIM_SvsGT = header[2]
    NRMSE_PvsGT = header[3]
    NRMSE_SvsGT = header[4]
    PSNR_PvsGT = header[5]
    PSNR_SvsGT = header[6]
    header = """
    <tr>
    <th width = 10% align="left">{0}</th>
    <th width = 15% align="left">{1}</th>
    <th width = 15% align="center">{2}</th>
    <th width = 15% align="left">{3}</th>
    <th width = 15% align="center">{4}</th>
    <th width = 15% align="left">{5}</th>
    <th width = 15% align="center">{6}</th>
    </tr>""".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)
    html = html+header
    for row in metrics:
      image = row[0]
      mSSIM_PvsGT = row[1]
      mSSIM_SvsGT = row[2]
      NRMSE_PvsGT = row[3]
      NRMSE_SvsGT = row[4]
      PSNR_PvsGT = row[5]
      PSNR_SvsGT = row[6]
      cells = """
        <tr>
          <td width = 10% align="left">{0}</td>
          <td width = 15% align="center">{1}</td>
          <td width = 15% align="center">{2}</td>
          <td width = 15% align="center">{3}</td>
          <td width = 15% align="center">{4}</td>
          <td width = 15% align="center">{5}</td>
          <td width = 15% align="center">{6}</td>
        </tr>""".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))
      html = html+cells
    html = html+"""</table></font></body>"""

  pdf.write_html(html)

  pdf.ln(1)
  pdf.set_font('')
  pdf.set_font_size(10.)
  ref_1 = 'References:\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. "Democratising deep learning for microscopy with ZeroCostDL4Mic." Nature Communications (2021).'
  pdf.multi_cell(190, 5, txt = ref_1, align='L')
  pdf.ln(1)
  ref_2 = '- Deep-STORM: Nehme, Elias, et al. "Deep-STORM: super-resolution single-molecule microscopy by deep learning." Optica 5.4 (2018): 458-464.'
  pdf.multi_cell(190, 5, txt = ref_2, align='L')
  pdf.ln(1)

  pdf.ln(3)
  reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'

  pdf.set_font('Arial', size = 11, style='B')
  pdf.multi_cell(190, 5, txt=reminder, align='C')
  pdf.ln(1)

  pdf.output(savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')


  print('------------------------------')
  print('QC PDF report exported as '+savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')

# Build requirements file for local run
after = [str(m) for m in sys.modules]
build_requirements_file(before, after)

# **2. Complete the Colab session**
---


## **2.1. Check for GPU access**
---

By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(Graphics processing unit)*


In [None]:
#@markdown ##Run this cell to check if you have GPU access
# %tensorflow_version 1.x

import tensorflow as tf
# if tf.__version__ != '2.2.0':
#   !pip install tensorflow==2.2.0

if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.')
  print('Did you change your runtime?')
  print('If the runtime settings are correct then Google did not allocate a GPU to your session.')
  print('Expect slow performance. To access a GPU, try reconnecting later.')

else:
  print('You have GPU access')
  !nvidia-smi

# from tensorflow.python.client import device_lib
# device_lib.list_local_devices()

# print the tensorflow version
print('Tensorflow version is ' + str(tf.__version__))


## **2.2. Mount your Google Drive**
---
<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive.

<font size = 4> Once this is done, your data are available in the **Files** tab on the top left of the notebook.

In [None]:
#@markdown ##Run this cell to connect your Google Drive to Colab

#@markdown * Click on the URL.

#@markdown * Sign in to your Google Account.

#@markdown * Copy the authorization code.

#@markdown * Enter the authorization code.

#@markdown * Click on the "Files" tab on the left and refresh it. Your Google Drive folder should now be available there as "gdrive".

#mounts user's Google Drive to Google Colab.

from google.colab import drive
drive.mount('/content/gdrive')


# **3. Generate patches for training**
---

For Deep-STORM the training data can be obtained in two ways:
* Simulated using ThunderSTORM or another simulation tool and loaded here (**using Section 3.1.a**)
* Directly simulated in this notebook (**using Section 3.1.b**)


## **3.1. - a) Load training data**
---

Here you can load your simulated data along with its corresponding localization file.
*   The `pixel_size` is defined in nanometers (nm).
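
As a quick sanity check before loading, the sketch below parses a small localization table and converts the nanometer coordinates to camera pixels. The column names `frame`, `x [nm]` and `y [nm]` match those written by the simulator in Section 3.1.b; the helper name `locs_to_pixels` and the inline sample values are illustrative assumptions, not part of the notebook.

```python
import csv
import io

# Hypothetical sample in a ThunderSTORM-style layout (unnamed index column first).
sample_csv = """,frame,x [nm],y [nm]
1,1,250.0,1200.0
2,1,3200.0,64.0
3,2,1000.0,1000.0
"""

def locs_to_pixels(csv_text, pixel_size):
    """Return (frame, x, y) tuples with x and y expressed in camera pixels."""
    reader = csv.DictReader(io.StringIO(csv_text))
    required = {"frame", "x [nm]", "y [nm]"}
    missing = required - set(reader.fieldnames or [])
    if missing:
        raise ValueError("Missing columns: " + ", ".join(sorted(missing)))
    return [
        (
            int(float(row["frame"])),          # frames may be stored as floats
            float(row["x [nm]"]) / pixel_size,  # nm -> pixels
            float(row["y [nm]"]) / pixel_size,
        )
        for row in reader
    ]

locs_px = locs_to_pixels(sample_csv, pixel_size=100)
print(locs_px[0])
```

If the coordinates fall outside the image dimensions reported when the stack is loaded, the `pixel_size` is probably not the one used to generate the localization file.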

In [None]:
#@markdown ##Load raw data

load_raw_data = True

# Get user input
ImageData_path = "" #@param {type:"string"}
LocalizationData_path = "" #@param {type: "string"}
#@markdown Get pixel size from file?
get_pixel_size_from_file = True #@param {type:"boolean"}
#@markdown Otherwise, use this value:
pixel_size = 100 #@param {type:"number"}

if get_pixel_size_from_file:
  pixel_size,_,_ = getPixelSizeTIFFmetadata(ImageData_path, True)

# load the tiff data
Images = io.imread(ImageData_path)
# get dataset dimensions
if len(Images.shape) == 3:
  (number_of_frames, M, N) = Images.shape
elif len(Images.shape) == 2:
  (M, N) = Images.shape
  number_of_frames = 1
print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')

# Interactive display of the stack
def scroll_in_time(frame):
    f=plt.figure(figsize=(6,6))
    plt.imshow(Images[frame-1], interpolation='nearest', cmap = 'gray')
    plt.title('Training source at frame = ' + str(frame))
    plt.axis('off');

if number_of_frames > 1:
  interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=1, continuous_update=False));
else:
  f=plt.figure(figsize=(6,6))
  plt.imshow(Images, interpolation='nearest', cmap = 'gray')
  plt.title('Training source')
  plt.axis('off');

# Load the localization file and display its last rows
LocData = pd.read_csv(LocalizationData_path, index_col=0)
LocData.tail()



## **3.1. - b) Simulate training data**
---
This simulation tool allows you to generate SMLM data of randomly distributed emitters in a field-of-view.
The assumptions are as follows:

*   Gaussian Point Spread Function (PSF) with standard deviation defined by `sigma`. The nominal value of `sigma` can be estimated using `sigma = 0.21 x Lambda / NA`, where `Lambda` is the emission wavelength and `NA` the numerical aperture of the objective (from [Zhang *et al.*, Applied Optics 2007](https://doi.org/10.1364/AO.46.001819)).
*   Each emitter emits on average `n_photons` per frame, and the detected signal is subject to the corresponding Poisson (shot) noise.
*   The camera contributes Gaussian noise to the signal with a standard deviation defined by `ReadOutNoise_ADC`, in ADC units.
*   The `emitter_density` **is defined as the number of emitters / um^2** on any given frame. Variability in the emitter density can be applied by adjusting `emitter_density_std`, which represents the standard deviation of the normal distribution from which the density is drawn for each individual frame.
*   The `n_photons` and `sigma` can additionally include some Gaussian variability by setting `n_photons_std` and `sigma_std`.
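
The noise model used by the simulation cell can be summarised as below. The symbols are notation introduced here for illustration: $I_0$ is the noise-free image in photons, $g$ is `ADC_per_photon_conversion`, $\sigma_r$ is `ReadOutNoise_ADC`, $o$ is `ADC_offset`, and $S$ is the peak pixel value (in photons) of a noise-free single-emitter image.

```latex
I = g \cdot \mathrm{Poisson}(I_0) + \sigma_r \, \mathcal{N}(0, 1) + o,
\qquad
\mathrm{SNR} = \frac{S}{\sqrt{(\sigma_r / g)^2 + S}}
```

Negative values of $I$ are set to zero before the stack is converted to unsigned integers. The SNR reported by the cell is the mean and standard deviation of this ratio over 100 randomly placed single emitters.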

Important note:
- All dimensions are in nanometers (e.g. `FOV_size` = 6400 represents a field of view of 6.4 um x 6.4 um).
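
To get a feel for sensible parameter values, the sketch below evaluates the two relations used above: the PSF width `sigma = 0.21 x Lambda / NA` and the mean number of emitters per frame for a given density and field of view. The function names and the example wavelength (680 nm) and NA (1.3) are assumptions chosen for illustration; substitute the values of your own setup.

```python
def psf_sigma_nm(wavelength_nm, numerical_aperture):
    """Approximate Gaussian PSF standard deviation: sigma = 0.21 * lambda / NA."""
    return 0.21 * wavelength_nm / numerical_aperture

def emitters_per_frame(emitter_density, fov_size_nm):
    """Mean number of emitters in a square FOV (density in emitters/um^2)."""
    fov_area_um2 = (fov_size_nm / 1000.0) ** 2  # nm -> um, then square
    return emitter_density * fov_area_um2

# Illustrative values only (assumed, not taken from the notebook defaults)
sigma_nm = psf_sigma_nm(wavelength_nm=680, numerical_aperture=1.3)
n_emitters = emitters_per_frame(emitter_density=6, fov_size_nm=6400)

print(round(sigma_nm, 1), "nm PSF sigma")
print(round(n_emitters, 2), "emitters per frame on average")
```

With these example values, `sigma` comes out close to the notebook default of 110 nm, and a density of 6 emitters / um^2 on a 6.4 um x 6.4 um field of view gives roughly 246 emitters per frame.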



In [None]:
load_raw_data = False

# ---------------------------- User input ----------------------------
#@markdown Run the simulation
#@markdown ---
#@markdown Camera settings:
FOV_size =  6400#@param {type:"number"}
pixel_size =  100#@param {type:"number"}
ADC_per_photon_conversion = 1 #@param {type:"number"}
ReadOutNoise_ADC =  4.5#@param {type:"number"}
ADC_offset =  50#@param {type:"number"}

#@markdown Acquisition settings:
emitter_density =  6#@param {type:"number"}
emitter_density_std =  0#@param {type:"number"}

number_of_frames =  20#@param {type:"integer"}

sigma = 110 #@param {type:"number"}
sigma_std = 5 #@param {type:"number"}
# NA =  1.1 #@param {type:"number"}
# wavelength =  800#@param {type:"number"}
# wavelength_std =  150#@param {type:"number"}
n_photons =  2250#@param {type:"number"}
n_photons_std =  250#@param {type:"number"}


# ---------------------------- Variable initialisation ----------------------------
# Start the clock to measure how long it takes
start = time.time()

print('-----------------------------------------------------------')
n_molecules = emitter_density*FOV_size*FOV_size/10**6
n_molecules_std = emitter_density_std*FOV_size*FOV_size/10**6
print('Number of molecules / FOV: '+str(round(n_molecules,2))+' +/- '+str((round(n_molecules_std,2))))

# sigma = 0.21*wavelength/NA
# sigma_std = 0.21*wavelength_std/NA
# print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')

M = round(FOV_size/pixel_size)
N = round(FOV_size/pixel_size)

FOV_size = M*pixel_size
print('Final image size: '+str(M)+'x'+str(N)+' ('+str(round(FOV_size/1000, 3))+' um x '+str(round(FOV_size/1000,3))+' um)')

np.random.seed(1)
display_upsampling = 8 # used to display the loc map here
NoiseFreeImages = np.zeros((number_of_frames, M, N))
locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))

frames = []
all_xloc = []
all_yloc = []
all_photons = []
all_sigmas = []

# ---------------------------- Main simulation loop ----------------------------
print('-----------------------------------------------------------')
for f in tqdm(range(number_of_frames)):

  # Define the coordinates of emitters by randomly distributing them across the FOV
  n_mol = int(max(round(np.random.normal(n_molecules, n_molecules_std, size=1)[0]), 0))
  x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)
  y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)
  photon_array = np.random.normal(n_photons, n_photons_std, size=n_mol)
  sigma_array = np.random.normal(sigma, sigma_std, size=n_mol)
  # x_c = np.linspace(0,3000,5)
  # y_c = np.linspace(0,3000,5)

  all_xloc += x_c.tolist()
  all_yloc += y_c.tolist()
  frames += ((f+1)*np.ones(x_c.shape[0])).tolist()
  all_photons += photon_array.tolist()
  all_sigmas += sigma_array.tolist()

  locImage[f] = FromLoc2Image_SimpleHistogram(x_c, y_c, image_size = (N*display_upsampling, M*display_upsampling), pixel_size = pixel_size/display_upsampling)

  # # Get the approximated locations according to the grid pixel size
  # Chr_emitters = [int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]
  # Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]

  # # Build Localization image
  # for (r,c) in zip(Rhr_emitters, Chr_emitters):
  #   locImage[f][r][c] += 1

  NoiseFreeImages[f] = FromLoc2Image_Erf(x_c, y_c, photon_array, sigma_array, image_size = (M,M), pixel_size = pixel_size)


# ---------------------------- Create DataFrame for localization file ----------------------------
# Table with localization info as dataframe output
LocData = pd.DataFrame()
LocData["frame"] = frames
LocData["x [nm]"] = all_xloc
LocData["y [nm]"] = all_yloc
LocData["Photon #"] = all_photons
LocData["Sigma [nm]"] = all_sigmas
LocData.index += 1  # set indices to start at 1 and not 0 (same as ThunderSTORM)


# ---------------------------- Estimation of SNR ----------------------------
n_frames_for_SNR = 100
M_SNR = 10
x_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)
y_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)
photon_array = np.random.normal(n_photons, n_photons_std, size=n_frames_for_SNR)
sigma_array = np.random.normal(sigma, sigma_std, size=n_frames_for_SNR)

SNR = np.zeros(n_frames_for_SNR)
for i in range(n_frames_for_SNR):
  SingleEmitterImage = FromLoc2Image_Erf(np.array([x_c[i]]), np.array([y_c[i]]), np.array([photon_array[i]]), np.array([sigma_array[i]]), (M_SNR, M_SNR), pixel_size)
  Signal_photon = np.max(SingleEmitterImage)
  Noise_photon = math.sqrt((ReadOutNoise_ADC/ADC_per_photon_conversion)**2 + Signal_photon)
  SNR[i] = Signal_photon/Noise_photon

print('SNR: '+str(round(np.mean(SNR),2))+' +/- '+str(round(np.std(SNR),2)))
# ---------------------------- ----------------------------


# Table with info
simParameters = pd.DataFrame()
simParameters["FOV size (nm)"] = [FOV_size]
simParameters["Pixel size (nm)"] = [pixel_size]
simParameters["ADC/photon"] = [ADC_per_photon_conversion]
simParameters["Read-out noise (ADC)"] = [ReadOutNoise_ADC]
simParameters["Constant offset (ADC)"] = [ADC_offset]

simParameters["Emitter density (emitters/um^2)"] = [emitter_density]
simParameters["STD of emitter density (emitters/um^2)"] = [emitter_density_std]
simParameters["Number of frames"] = [number_of_frames]
# simParameters["NA"] = [NA]
# simParameters["Wavelength (nm)"] = [wavelength]
# simParameters["STD of wavelength (nm)"] = [wavelength_std]
simParameters["Sigma (nm)"] = [sigma]
simParameters["STD of Sigma (nm)"] = [sigma_std]
simParameters["Number of photons"] = [n_photons]
simParameters["STD of number of photons"] = [n_photons_std]
simParameters["SNR"] = [np.mean(SNR)]
simParameters["STD of SNR"] = [np.std(SNR)]


# ---------------------------- Finish simulation ----------------------------
# Calculating the noisy image
Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset
Images[Images <= 0] = 0

# Convert to 16-bit or 32-bit unsigned integers
if Images.max() < (2**16-1):
  Images = Images.astype(np.uint16)
else:
  Images = Images.astype(np.uint32)


# ---------------------------- Display ----------------------------
# Displaying the time elapsed for simulation
dt = time.time() - start
minutes, seconds = divmod(dt, 60)
hours, minutes = divmod(minutes, 60)
print("Time elapsed:",hours, "hour(s)",minutes,"min(s)",round(seconds,1),"sec(s)")


# Interactively display the results using Widgets
def scroll_in_time(frame):
  f = plt.figure(figsize=(18,6))
  plt.subplot(1,3,1)
  plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)
  plt.title('Localization image')
  plt.axis('off');

  plt.subplot(1,3,2)
  plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest', cmap='gray')
  plt.title('Noise-free simulation')
  plt.axis('off');

  plt.subplot(1,3,3)
  plt.imshow(Images[frame-1], interpolation='nearest', cmap='gray')
  plt.title('Noisy simulation')
  plt.axis('off');

interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=1, continuous_update=False));

# Display the last rows of the dataframe with localizations
LocData.tail()


In [None]:
#@markdown ---
#@markdown ##Play this cell to save the simulated stack
#@markdown Please select a path to the folder where the simulated data will be saved. It is not necessary to save the data to run the training, but keeping the simulated data for your own records can be useful to check its validity.
Save_path = "" #@param {type:"string"}

if not os.path.exists(Save_path):
  os.makedirs(Save_path)
  print('Folder created.')
else:
  print('Training data already exists in folder: Data overwritten.')

saveAsTIF(Save_path, 'SimulatedDataset', Images, pixel_size)
# io.imsave(os.path.join(Save_path, 'SimulatedDataset.tif'),Images)
LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))
simParameters.to_csv(os.path.join(Save_path, 'SimulatedParameters.csv'))
print('Training dataset saved.')

## **3.2. Generate training patches**
---

Training patches need to be created from the training data generated above.
*   The `patch_size` needs to give sufficient contextual information and for most cases a `patch_size` of 26 (corresponding to patches of 26x26 pixels) works fine. **DEFAULT: 26**
*   The `upsampling_factor` defines the effective magnification of the final super-resolved image compared to the input image (this is called magnification in ThunderSTORM). This is used to generate the super-resolved patches as target dataset. Using an `upsampling_factor` of 16 will require more memory, and it may be necessary to decrease the `patch_size` to 16, for example. **DEFAULT: 8**
*   The `num_patches_per_frame` defines the number of patches extracted from each frame generated in section 3.1. **DEFAULT: 500**
*   The `min_number_of_emitters_per_patch` defines the minimum number of emitters that need to be present in the patch to be a valid patch. An empty patch does not contain useful information for the network to learn from. **DEFAULT: 7**
*   The `max_num_patches` defines the maximum number of patches to generate. Fewer may be generated depending on how many patches are rejected and how many frames are available. **DEFAULT: 10000**
*   The `gaussian_sigma` defines the Gaussian standard deviation (in magnified pixels) applied to generate the super-resolved target image. **DEFAULT: 1**
*   The `L2_weighting_factor` is a normalization factor used in the loss function. It helps to balance the loss from the L2 norm. When using higher densities, this factor should be decreased and vice-versa. This factor can be automatically calculated using an empirical formula. **DEFAULT: 100**
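
The effect of these parameters on memory use and on the automatic normalization can be sketched as follows. This is a minimal illustration, not part of the notebook: the function names are hypothetical, the memory estimate assumes three `float32` stacks of upsampled patches (raw patches, localization maps and heat maps), and the second function only restates the empirical formula used in the cell below, taking the expected number of localizations per patch as an input.

```python
import math

BYTES_PER_FLOAT32 = 4


def training_arrays_size_bytes(n_patches, patch_size, upsampling_factor, n_arrays=3):
    """Memory needed for n_arrays float32 stacks of square, upsampled patches.

    Hypothetical helper: assumes each stack has shape
    (n_patches, patch_size * upsampling_factor, patch_size * upsampling_factor).
    """
    side = patch_size * upsampling_factor  # patch side length on the high-res grid
    return n_arrays * n_patches * side * side * BYTES_PER_FLOAT32


def automatic_l2_weighting_factor(n_locs_per_patch, min_emitters, upsampling_factor):
    """Empirical factor, referenced to 100 for 20.28 emitters per patch at 8x upsampling."""
    n_effective = min(n_locs_per_patch, min_emitters)
    return 100 / math.sqrt(n_effective * 8**2 / (upsampling_factor**2 * 20.28))


if __name__ == "__main__":
    # Compare the default settings with a 16x upsampling, with and without a smaller patch
    for patch_size, factor in [(26, 8), (26, 16), (16, 16)]:
        size_gib = training_arrays_size_bytes(10000, patch_size, factor) / 1024**3
        print(f"patch_size={patch_size}, upsampling_factor={factor}: {size_gib:.1f} GiB")
    print("L2 weighting factor:", automatic_l2_weighting_factor(20.28, 100, 8))
```

Under these assumptions, going from an `upsampling_factor` of 8 to 16 at a fixed `patch_size` quadruples the memory footprint, which is why a smaller `patch_size` is suggested in that case.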



In [None]:
#@markdown ## **Provide patch parameters**


# -------------------- User input --------------------
patch_size = 26 #@param {type:"integer"}
upsampling_factor = 8 #@param ["4", "8", "16"] {type:"raw"}
num_patches_per_frame =  500#@param {type:"integer"}
min_number_of_emitters_per_patch = 7#@param {type:"integer"}
max_num_patches =  10000#@param {type:"integer"}
gaussian_sigma = 1#@param {type:"integer"}

#@markdown Estimate the optimal normalization factor automatically?
Automatic_normalization = True #@param {type:"boolean"}
#@markdown Otherwise, it will use the following value:
L2_weighting_factor = 100 #@param {type:"number"}


# -------------------- Prepare variables --------------------
# Start the clock to measure how long it takes
start = time.time()

# Initialize some parameters
pixel_size_hr = pixel_size/int(upsampling_factor) # in nm
n_patches = min(int(number_of_frames)*int(num_patches_per_frame), int(max_num_patches))
patch_size = int(patch_size)*int(upsampling_factor)

# Dimensions of the high-res grid
Mhr = int(upsampling_factor)*M # in pixels
Nhr = int(upsampling_factor)*N # in pixels

# Initialize the training patches and labels
patches = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)
spikes = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)
heatmaps = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)

# Run over all frames and construct the training examples
k = 1 # current patch count
skip_counter = 0 # number of patches skipped due to low density
id_start = 0 # id position in LocData for current frame
print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))

n_locs = len(LocData.index)
print('Total number of localizations: '+str(n_locs))
density = n_locs/(M*N*number_of_frames*(0.001*pixel_size)**2)
print('Density: '+str(round(density,2))+' locs/um^2')
n_locs_per_patch = patch_size**2*density

if Automatic_normalization:
  # This empirical formula attempts to balance the L2 loss function between the background and the bright spikes
  # A value of 100 was originally chosen to balance L2 for a patch size of 2.6x2.6 um^2 (0.1 um pixel size) and a density of 3 emitters/um^2 (hence the 20.28), at upsampling_factor = 8
  L2_weighting_factor = 100/math.sqrt(min(n_locs_per_patch, min_number_of_emitters_per_patch)*8**2/(int(upsampling_factor)**2*20.28))
  print('Normalization factor: '+str(round(L2_weighting_factor,2)))

# -------------------- Patch generation loop --------------------

print('-----------------------------------------------------------')
for (f, thisFrame) in enumerate(tqdm(Images)):

  # Upsample the frame
  upsampledFrame = np.kron(thisFrame, np.ones((int(upsampling_factor),int(upsampling_factor))))
  # Read all the provided high-resolution locations for current frame
  DataFrame = LocData[LocData['frame'] == f+1].copy()

  # Get the approximated locations according to the high-res grid pixel size
  Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]
  Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]
  id_start += len(DataFrame.index)

  # Build Localization image
  LocImage = np.zeros((Mhr,Nhr))
  LocImage[(Rhr_emitters, Chr_emitters)] = 1

  # Here, there's a choice between the original Gaussian (classification approach) and using the erf function
  HeatMapImage = L2_weighting_factor*gaussian_filter(LocImage, float(gaussian_sigma))
  # HeatMapImage = L2_weighting_factor*FromLoc2Image_MultiThreaded(np.array(list(DataFrame['x [nm]'])), np.array(list(DataFrame['y [nm]'])),
                                                            #  np.ones(len(DataFrame.index)), pixel_size_hr*gaussian_sigma*np.ones(len(DataFrame.index)),
                                                            #  Mhr, pixel_size_hr)


  # Generate random position for the top left corner of the patch
  xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)
  yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)

  for c in range(len(xc)):
    if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:
      skip_counter += 1
      continue

    else:
      # Limit the number of training examples to max_num_patches
      if k > max_num_patches:
        break
      else:
        # Assign the patches to the right part of the images
        patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]
        spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]
        heatmaps[k-1] = HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]
        k += 1 # increment current patch count

# Remove the empty data
patches = patches[:k-1]
spikes = spikes[:k-1]
heatmaps = heatmaps[:k-1]
n_patches = k-1

# -------------------- Failsafe --------------------
# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM
if ((k-1) < 5000):
  # W  = '\033[0m'  # white (normal)
  # R  = '\033[31m' # red
  print(bcolors.WARNING+'!! WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. !!'+bcolors.NORMAL)



# -------------------- Displays --------------------
print('Number of patches skipped due to low density: '+str(skip_counter))
# dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB
# print('Size of patches: '+str(dataSize)+' MB')
print(str(n_patches)+' patches were generated.')

# Displaying the time elapsed for patch generation
dt = time.time() - start
minutes, seconds = divmod(dt, 60)
hours, minutes = divmod(minutes, 60)
print("Time elapsed:",hours, "hour(s)",minutes,"min(s)",round(seconds),"sec(s)")

# Display patches interactively with a slider
def scroll_patches(patch):
  f = plt.figure(figsize=(16,6))
  plt.subplot(1,3,1)
  plt.imshow(patches[patch-1], interpolation='nearest', cmap='gray')
  plt.title('Raw data (patch #'+str(patch)+')')
  plt.axis('off');

  plt.subplot(1,3,2)
  plt.imshow(heatmaps[patch-1], interpolation='nearest')
  plt.title('Heat map')
  plt.axis('off');

  plt.subplot(1,3,3)
  plt.imshow(spikes[patch-1], interpolation='nearest')
  plt.title('Localization map')
  plt.axis('off');

  plt.savefig(base_path + '/TrainingDataExample_DeepSTORM2D.png',bbox_inches='tight',pad_inches=0)


interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=1, continuous_update=False));




# **4. Train the network**
---

## **4.1. Select your paths and parameters**

---

<font size = 4>**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).

<font size = 4>**`model_name`:** Use only my_model -style, not my-model (Use "_" not "-"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.


<font size = 5>**Training parameters**

<font size = 4>**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained for. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for ~100 epochs. Evaluate the performance after training (see section 5). **Default value: 80**

<font size = 4>**`batch_size`:** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**

<font size = 4>**`number_of_steps`:** Define the number of training steps per epoch. **If this value is set to 0**, it is calculated automatically so that each training patch is seen approximately once per epoch. **Default value: number of training patches / batch_size**

<font size = 4>**`percentage_validation`:**  Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 30**

<font size = 4>**`initial_learning_rate`:** This parameter represents the initial value to be used as learning rate in the optimizer. **Default value: 0.001**
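
<font size = 4>As an illustration of how `number_of_steps` follows from the other parameters when it is left at 0, here is a minimal sketch. The function name is hypothetical; the arithmetic mirrors the rule applied in the cell below, where the validation fraction is removed before dividing by the batch size.

```python
def steps_per_epoch(n_patches, batch_size, percentage_validation):
    """Number of training steps needed to see each training patch about once.

    percentage_validation is given in percent (e.g. 30 for 30%).
    """
    training_fraction = 1 - percentage_validation / 100
    return int(training_fraction * n_patches / batch_size)


if __name__ == "__main__":
    # Default settings: 10000 patches, batches of 16, 30% kept for validation
    print(steps_per_epoch(10000, 16, 30))
```

<font size = 4>Because the result is truncated to an integer, a few patches may be left unseen in a given epoch.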

In [None]:
#@markdown ###Path to training images and parameters

model_path = "" #@param {type: "string"}
model_name = "" #@param {type: "string"}
number_of_epochs =  80#@param {type:"integer"}
batch_size =  16#@param {type:"integer"}

number_of_steps =  0#@param {type:"integer"}
percentage_validation = 30 #@param {type:"number"}
initial_learning_rate = 0.001 #@param {type:"number"}


percentage_validation /= 100
if number_of_steps == 0:
  number_of_steps = int((1-percentage_validation)*n_patches/batch_size)
  print('Number of steps: '+str(number_of_steps))

# Pretrained model path initialised here so next cell does not need to be run
h5_file_path = ''
Use_pretrained_model = False

if not ('patches' in globals()):
  # W  = '\033[0m'  # white (normal)
  # R  = '\033[31m' # red
  print(bcolors.WARNING+'!! WARNING: No patches were found in memory currently. !!'+bcolors.NORMAL)

Save_path = os.path.join(model_path, model_name)
if os.path.exists(Save_path):
  print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)

print('-----------------------------')
print('Training parameters set.')



## **4.2. Using weights from a pre-trained model as initial weights**
---
<font size = 4>  Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deep-STORM 2D model**.

<font size = 4> This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.

<font size = 4> In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used.
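
<font size = 4>The learning rate selection described above can be sketched as follows. This is a simplified illustration rather than the notebook's implementation: the function name is hypothetical, the log is passed as text instead of being read from the `Quality Control` folder, and only the `val_loss` and `learning rate` column names are taken from the training log used in this notebook.

```python
import csv
import io


def learning_rate_from_log(csv_text, weights_choice, default_learning_rate):
    """Pick the learning rate matching the 'last' or 'best' weights.

    Falls back to default_learning_rate when the log has no 'learning rate' column.
    """
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    if not rows or "learning rate" not in rows[0]:
        return default_learning_rate
    if weights_choice == "last":
        return float(rows[-1]["learning rate"])
    # "best": the row with the lowest validation loss (the last one in case of ties)
    best_row = min(reversed(rows), key=lambda row: float(row["val_loss"]))
    return float(best_row["learning rate"])


if __name__ == "__main__":
    # Made-up training log, for illustration only
    example_log = "\n".join([
        "loss,val_loss,learning rate",
        "0.90,0.80,0.001",
        "0.50,0.40,0.001",
        "0.30,0.45,0.0005",
    ])
    print(learning_rate_from_log(example_log, "last", 0.001))
    print(learning_rate_from_log(example_log, "best", 0.001))
```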

In [None]:
# @markdown ##Loading weights from a pre-trained network

Use_pretrained_model = False #@param {type:"boolean"}
pretrained_model_choice = "Model_from_file" #@param ["Model_from_file"]
Weights_choice = "best" #@param ["last", "best"]

#@markdown ###If you chose "Model_from_file", please provide the path to the model folder:
pretrained_model_path = "" #@param {type:"string"}

# --------------------- Check if we load a previously trained model ------------------------
if Use_pretrained_model:

# --------------------- Load the model from the chosen path ------------------------
  if pretrained_model_choice == "Model_from_file":
    h5_file_path = os.path.join(pretrained_model_path, "weights_"+Weights_choice+".hdf5")

# --------------------- Download a model provided in the XXX ------------------------

  if pretrained_model_choice == "Model_name":
    pretrained_model_name = "Model_name"
    pretrained_model_path = base_path + "/" + pretrained_model_name
    print("Downloading the 2D_Demo_Model_from_Stardist_2D_paper")
    if os.path.exists(pretrained_model_path):
      shutil.rmtree(pretrained_model_path)
    os.makedirs(pretrained_model_path)
    wget.download("", pretrained_model_path)
    wget.download("", pretrained_model_path)
    wget.download("", pretrained_model_path)
    wget.download("", pretrained_model_path)
    h5_file_path = os.path.join(pretrained_model_path, "weights_"+Weights_choice+".hdf5")

# --------------------- Add additional pre-trained models here ------------------------



# --------------------- Check the model exists ------------------------
# If the chosen model path does not contain a pretrained model, Use_pretrained_model is disabled
  if not os.path.exists(h5_file_path):
    print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.hdf5 pretrained model does not exist'+bcolors.NORMAL)
    Use_pretrained_model = False


# If the model path contains a pretrained model, we load the learning rate
  if os.path.exists(h5_file_path):
# Here we check whether the learning rate can be loaded from the Quality Control folder
    if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):
      with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:
        csvRead = pd.read_csv(csvfile, sep=',')
        #print(csvRead)
        if "learning rate" in csvRead.columns: # Here we check that the learning rate column exists (compatibility with models trained in ZeroCostDL4Mic below 1.4)
          print("pretrained network learning rate found")
          #find the last learning rate
          lastLearningRate = csvRead["learning rate"].iloc[-1]
          #Find the learning rate corresponding to the lowest validation loss
          min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]
          #print(min_val_loss)
          bestLearningRate = min_val_loss['learning rate'].iloc[-1]
          if Weights_choice == "last":
            print('Last learning rate: '+str(lastLearningRate))
          if Weights_choice == "best":
            print('Learning rate of best validation loss: '+str(bestLearningRate))
        if not "learning rate" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead
          bestLearningRate = initial_learning_rate
          lastLearningRate = initial_learning_rate
          print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead.'+bcolors.NORMAL)

#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used
    if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):
      print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+bcolors.NORMAL)
      bestLearningRate = initial_learning_rate
      lastLearningRate = initial_learning_rate


# Display info about the pretrained model to be loaded (or not)
if Use_pretrained_model:
  print('Weights found in:')
  print(h5_file_path)
  print('will be loaded prior to training.')

else:
  print('No pretrained network will be used.')
  h5_file_path = ''



## **4.3. Start Training**
---
<font size = 4>When playing the cell below you should see updates after each epoch (round). Network training can take some time.

<font size = 4>* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for data mining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or the number of patches.

<font size = 4>Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 4.1. It is however wise to download the folder from Google Drive, as all data can be erased at the next training if the same folder is used.

In [None]:
#@markdown ##Start training

# Start the clock to measure how long it takes
start = time.time()

# --------------------- Using pretrained model ------------------------
# Here we ensure that the learning rate is set correctly when using pre-trained models
if Use_pretrained_model:
  if Weights_choice == "last":
    initial_learning_rate = lastLearningRate

  if Weights_choice == "best":
    initial_learning_rate = bestLearningRate
# --------------------- ---------------------- ------------------------


# Here we check whether a model with the same name already exists; if so, it is deleted
if os.path.exists(Save_path):
  shutil.rmtree(Save_path)

# Create the model folder!
os.makedirs(Save_path)

# Export pdf summary
pdf_export(raw_data = load_raw_data, pretrained_model = Use_pretrained_model)

# Let's go !
train_model(patches, heatmaps, Save_path,
            steps_per_epoch=number_of_steps, epochs=number_of_epochs, batch_size=batch_size,
            upsampling_factor = upsampling_factor,
            validation_split = percentage_validation,
            initial_learning_rate = initial_learning_rate,
            pretrained_model_path = h5_file_path,
            L2_weighting_factor = L2_weighting_factor)

# # Show info about the GPU memory usage
# !nvidia-smi

# Displaying the time elapsed for training
dt = time.time() - start
minutes, seconds = divmod(dt, 60)
hours, minutes = divmod(minutes, 60)
print("Time elapsed:",hours, "hour(s)",minutes,"min(s)",round(seconds),"sec(s)")

# export pdf after training to update the existing document
pdf_export(trained = True, raw_data = load_raw_data, pretrained_model = Use_pretrained_model)


# **5. Evaluate your model**
---

<font size = 4>This section allows the user to perform important quality checks on the validity and generalisability of the trained model.

<font size = 4>**We highly recommend performing quality control on all newly trained models.**

In [None]:
# model name and path
#@markdown ###Do you want to assess the model you just trained ?
Use_the_current_trained_model = True #@param {type:"boolean"}

#@markdown ###If not, please provide the path to the model folder:
#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter `model_name` (see section 4.1). Provide the name of this folder as `QC_model_path` .

QC_model_path = "" #@param {type:"string"}

if (Use_the_current_trained_model):
  QC_model_path = os.path.join(model_path, model_name)

if os.path.exists(QC_model_path):
  print("The "+os.path.basename(QC_model_path)+" model will be evaluated")
else:
  print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)
  print('Please make sure you provide a valid model path before proceeding further.')


## **5.1. Inspection of the loss function**
---

<font size = 4>First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*

<font size = 4>**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.

<font size = 4>**Validation loss** describes the same error value between the model's prediction on a validation image and its target.

<font size = 4>During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.

<font size = 4>Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.
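
<font size = 4>The reading of the loss curves described above can be turned into a rough automatic check. The sketch below is only a heuristic for illustration: the function name, the three labels and the 5% tolerance are assumptions and are not part of this notebook. It assumes positive loss values listed in epoch order.

```python
def diagnose_losses(train_loss, val_loss, tolerance=0.05):
    """Classify a training run from its per-epoch training and validation losses.

    Returns a (label, best_epoch) tuple, where best_epoch is the zero-based
    index of the epoch with the lowest validation loss.
    """
    best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)
    # Validation loss has risen noticeably above its minimum...
    val_rebound = val_loss[-1] > val_loss[best_epoch] * (1 + tolerance)
    # ...while the training loss kept decreasing after that minimum
    train_still_falling = train_loss[-1] < train_loss[best_epoch]
    if val_rebound and train_still_falling:
        return "overfitting", best_epoch
    if best_epoch == len(val_loss) - 1:
        return "still improving", best_epoch
    return "converged", best_epoch


if __name__ == "__main__":
    # Made-up loss values, for illustration only
    print(diagnose_losses([1.0, 0.5, 0.3, 0.2, 0.1], [1.0, 0.6, 0.5, 0.6, 0.8]))
```

<font size = 4>Such a check does not replace a visual inspection of the curves, in particular on the log scale, where late changes are easier to see.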

In [None]:
#@markdown ##Play the cell to show a plot of training errors vs. epoch number


lossDataFromCSV = []
vallossDataFromCSV = []

with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:
    csvRead = csv.reader(csvfile, delimiter=',')
    next(csvRead)
    for row in csvRead:
      if row:
        lossDataFromCSV.append(float(row[0]))
        vallossDataFromCSV.append(float(row[1]))

epochNumber = range(len(lossDataFromCSV))
plt.figure(figsize=(15,10))

plt.subplot(2,1,1)
plt.plot(epochNumber,lossDataFromCSV, label='Training loss')
plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')
plt.title('Training loss and validation loss vs. epoch number (linear scale)')
plt.ylabel('Loss')
plt.xlabel('Epoch number')
plt.legend()

plt.subplot(2,1,2)
plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')
plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')
plt.title('Training loss and validation loss vs. epoch number (log scale)')
plt.ylabel('Loss')
plt.xlabel('Epoch number')
plt.legend()
plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'), bbox_inches='tight', pad_inches=0)
plt.show()



## **5.2. Error mapping and quality metrics estimation**
---

<font size = 4>This section will display SSIM maps and RSE maps, and calculate the total SSIM, NRMSE and PSNR metrics for all the images provided in the "QC_image_folder", using the corresponding localization data contained in "QC_loc_folder".

<font size = 4>**1. The SSIM (structural similarity) map**

<font size = 4>The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info).

<font size=4>**mSSIM** is the SSIM value calculated across the entire window of both images.

<font size=4>**The output below shows the SSIM maps with the mSSIM**
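
<font size = 4>For reference, the SSIM between two image windows $x$ and $y$ is usually written as shown below, where $\mu$ denotes the local mean, $\sigma^2$ the local variance, $\sigma_{xy}$ the local covariance and $L$ the data range (1 for images normalized to [0, 1]). The constants $K_1 = 0.01$ and $K_2 = 0.03$ are the values commonly used in the literature and are assumed here for illustration.

$$
\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)\,(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)\,(\sigma_x^2 + \sigma_y^2 + C_2)}, \qquad C_1 = (K_1 L)^2, \quad C_2 = (K_2 L)^2
$$

<font size = 4>Two identical windows give $\mu_x = \mu_y$ and $\sigma_{xy} = \sigma_x^2 = \sigma_y^2$, so the numerator equals the denominator and the SSIM is 1.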

<font size = 4>**2. The RSE (Root Squared Error) map**

<font size = 4>This is a display of the root of the squared difference between the normalized prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).


<font size =4>**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.

<font size = 4>**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.

<font size=4>**The output below shows the RSE maps with the NRMSE and PSNR values.**
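
<font size = 4>The three quantities above can be sketched with NumPy as follows, using their textbook definitions. The function names are hypothetical, and both images are assumed to be already normalized to the same range (here [0, 1], hence a data range of 1).

```python
import numpy as np


def rse_map(target, prediction):
    """Pixel-wise root squared error, i.e. the absolute difference at each pixel."""
    return np.sqrt(np.square(target - prediction))


def nrmse(target, prediction):
    """Root mean squared error between two already-normalized images."""
    return float(np.sqrt(np.mean(np.square(target - prediction))))


def psnr_db(target, prediction, data_range=1.0):
    """Peak signal-to-noise ratio in decibels (infinite for identical images)."""
    mse = np.mean(np.square(target - prediction))
    return float(10 * np.log10(data_range**2 / mse))


if __name__ == "__main__":
    # Toy example: a constant offset of 0.1 between two 4x4 images
    target = np.zeros((4, 4))
    prediction = np.full((4, 4), 0.1)
    print("max of RSE map:", rse_map(target, prediction).max())
    print("NRMSE:", nrmse(target, prediction))
    print("PSNR (dB):", psnr_db(target, prediction))
```

<font size = 4>A lower NRMSE and a higher PSNR both indicate a closer agreement, since the PSNR decreases as the mean squared error grows.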





In [None]:

# ------------------------ User input ------------------------
#@markdown ##Choose the folders that contain your Quality Control dataset
QC_image_folder = "" #@param{type:"string"}
QC_loc_folder = "" #@param{type:"string"}
#@markdown Get pixel size from file?
get_pixel_size_from_file = True #@param {type:"boolean"}
#@markdown Otherwise, use this value:
pixel_size = 100 #@param {type:"number"}

if get_pixel_size_from_file:
  pixel_size_INPUT = None
else:
  pixel_size_INPUT = pixel_size


# ------------------------ QC analysis loop over provided dataset ------------------------

savePath = os.path.join(QC_model_path, 'Quality Control')

# Open and create the csv file that will contain all the QC metrics
with open(os.path.join(savePath, os.path.basename(QC_model_path)+"_QC_metrics.csv"), "w", newline='') as file:
  writer = csv.writer(file)

  # Write the header in the csv file
  writer.writerow(["image #","Prediction v. GT mSSIM","WF v. GT mSSIM", "Prediction v. GT NRMSE","WF v. GT NRMSE", "Prediction v. GT PSNR", "WF v. GT PSNR"])

  # These lists will be used to collect all the metrics values per slice
  file_name_list = []
  slice_number_list = []
  mSSIM_GvP_list = []
  mSSIM_GvWF_list = []
  NRMSE_GvP_list = []
  NRMSE_GvWF_list = []
  PSNR_GvP_list = []
  PSNR_GvWF_list = []

  # Let's loop through the provided dataset in the QC folders

  for (imageFilename, locFilename) in zip(list_files(QC_image_folder, 'tif'), list_files(QC_loc_folder, 'csv')):
    print('--------------')
    print(imageFilename)
    print(locFilename)

    # Get the prediction
    batchFramePredictionLocalization(QC_image_folder, imageFilename, QC_model_path, savePath, pixel_size = pixel_size_INPUT)

    # test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);
    thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))
    thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))

    Mhr = thisPrediction.shape[0]
    Nhr = thisPrediction.shape[1]

    if pixel_size_INPUT is None:
      pixel_size, N, M = getPixelSizeTIFFmetadata(os.path.join(QC_image_folder,imageFilename))

    upsampling_factor = int(Mhr/M)
    print('Upsampling factor: '+str(upsampling_factor))
    pixel_size_hr = pixel_size/upsampling_factor # in nm

    # Load the localization file
    LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)

    x = np.array(list(LocData['x [nm]']))
    y = np.array(list(LocData['y [nm]']))
    locImage = FromLoc2Image_SimpleHistogram(x, y, image_size = (Mhr,Nhr), pixel_size = pixel_size_hr)

    # Remove extension from filename
    imageFilename_no_extension = os.path.splitext(imageFilename)[0]

    # io.imsave(os.path.join(savePath, 'GT_image_'+imageFilename), locImage)
    saveAsTIF(savePath, 'GT_image_'+imageFilename_no_extension, locImage, pixel_size_hr)

    # Normalize the images wrt each other by minimizing the MSE between GT and prediction
    test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, normalize_gt=True)
    # Normalize the images wrt each other by minimizing the MSE between GT and Source image
    test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)

    # -------------------------------- Calculate the metric maps and save them --------------------------------

    # Calculate the SSIM maps
    index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)
    index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)


    # Save ssim_maps
    img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)
    # io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)
    saveAsTIF(savePath,'SSIM_GTvsPrediction_'+imageFilename_no_extension, img_SSIM_GTvsPrediction_32bit, pixel_size_hr)


    img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)
    # io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)
    saveAsTIF(savePath,'SSIM_GTvsWF_'+imageFilename_no_extension, img_SSIM_GTvsWF_32bit, pixel_size_hr)


    # Calculate the Root Squared Error (RSE) maps
    img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))
    img_RSE_GTvsWF = np.sqrt(np.square(test_GT_norm - test_wf_norm))

    # Save RSE maps
    img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)
    # io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)
    saveAsTIF(savePath,'RSE_GTvsPrediction_'+imageFilename_no_extension, img_RSE_GTvsPrediction_32bit, pixel_size_hr)

    img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)
    # io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)
    saveAsTIF(savePath,'RSE_GTvsWF_'+imageFilename_no_extension, img_RSE_GTvsWF_32bit, pixel_size_hr)


    # -------------------------------- Calculate the RSE metrics and save them --------------------------------

    # Normalised Root Mean Squared Error: the RSE maps hold absolute differences, so they are squared again before averaging
    NRMSE_GTvsPrediction = np.sqrt(np.mean(np.square(img_RSE_GTvsPrediction)))
    NRMSE_GTvsWF = np.sqrt(np.mean(np.square(img_RSE_GTvsWF)))

    # We can also measure the peak signal to noise ratio between the images
    PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)
    PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)

    writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])

    # Collect values to display in dataframe output
    file_name_list.append(imageFilename)
    mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)
    mSSIM_GvWF_list.append(index_SSIM_GTvsWF)
    NRMSE_GvP_list.append(NRMSE_GTvsPrediction)
    NRMSE_GvWF_list.append(NRMSE_GTvsWF)
    PSNR_GvP_list.append(PSNR_GTvsPrediction)
    PSNR_GvWF_list.append(PSNR_GTvsWF)


# Table with metrics as dataframe output
pdResults = pd.DataFrame(index = file_name_list)
pdResults["Prediction v. GT mSSIM"] = mSSIM_GvP_list
pdResults["Wide-field v. GT mSSIM"] = mSSIM_GvWF_list
pdResults["Prediction v. GT NRMSE"] = NRMSE_GvP_list
pdResults["Wide-field v. GT NRMSE"] = NRMSE_GvWF_list
pdResults["Prediction v. GT PSNR"] = PSNR_GvP_list
pdResults["Wide-field v. GT PSNR"] = PSNR_GvWF_list


# ------------------------ Display ------------------------

print('--------------------------------------------')
@interact
def show_QC_results(file = list_files(QC_image_folder, 'tif')):

  plt.figure(figsize=(15,15))
  # Target (Ground-truth)
  plt.subplot(3,3,1)
  plt.axis('off')
  img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))
  plt.imshow(img_GT, norm = simple_norm(img_GT, percent = 99.5))
  plt.title('Target',fontsize=15)

  # Wide-field
  plt.subplot(3,3,2)
  plt.axis('off')
  img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))
  plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))
  plt.title('Widefield',fontsize=15)

  #Prediction
  plt.subplot(3,3,3)
  plt.axis('off')
  img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))
  plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))
  plt.title('Prediction',fontsize=15)

  #Setting up colours
  cmap = plt.cm.CMRmap

  #SSIM between GT and Source
  plt.subplot(3,3,5)
  #plt.axis('off')
  plt.tick_params(
      axis='both',      # changes apply to the x-axis and y-axis
      which='both',      # both major and minor ticks are affected
      bottom=False,      # ticks along the bottom edge are off
      top=False,        # ticks along the top edge are off
      left=False,       # ticks along the left edge are off
      right=False,         # ticks along the right edge are off
      labelbottom=False,
      labelleft=False)
  img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))
  imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)
  plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)
  plt.title('Target vs. Widefield',fontsize=15)
  plt.xlabel('mSSIM: '+str(round(pdResults.loc[file]["Wide-field v. GT mSSIM"],3)),fontsize=14)
  plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)

  #SSIM between GT and Prediction
  plt.subplot(3,3,6)
  #plt.axis('off')
  plt.tick_params(
      axis='both',      # changes apply to the x-axis and y-axis
      which='both',      # both major and minor ticks are affected
      bottom=False,      # ticks along the bottom edge are off
      top=False,        # ticks along the top edge are off
      left=False,       # ticks along the left edge are off
      right=False,         # ticks along the right edge are off
      labelbottom=False,
      labelleft=False)
  img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))
  imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)
  plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)
  plt.title('Target vs. Prediction',fontsize=15)
  plt.xlabel('mSSIM: '+str(round(pdResults.loc[file]["Prediction v. GT mSSIM"],3)),fontsize=14)

  #Root Squared Error between GT and Source
  plt.subplot(3,3,8)
  #plt.axis('off')
  plt.tick_params(
      axis='both',      # changes apply to the x-axis and y-axis
      which='both',      # both major and minor ticks are affected
      bottom=False,      # ticks along the bottom edge are off
      top=False,        # ticks along the top edge are off
      left=False,       # ticks along the left edge are off
      right=False,         # ticks along the right edge are off
      labelbottom=False,
      labelleft=False)
  img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))
  imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)
  plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)
  plt.title('Target vs. Widefield',fontsize=15)
  plt.xlabel('NRMSE: '+str(round(pdResults.loc[file]["Wide-field v. GT NRMSE"],3))+', PSNR: '+str(round(pdResults.loc[file]["Wide-field v. GT PSNR"],3)),fontsize=14)
  plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)

  #Root Squared Error between GT and Prediction
  plt.subplot(3,3,9)
  #plt.axis('off')
  plt.tick_params(
      axis='both',      # changes apply to the x-axis and y-axis
      which='both',      # both major and minor ticks are affected
      bottom=False,      # ticks along the bottom edge are off
      top=False,        # ticks along the top edge are off
      left=False,       # ticks along the left edge are off
      right=False,         # ticks along the right edge are off
      labelbottom=False,
      labelleft=False)
  img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))
  imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)
  plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)
  plt.title('Target vs. Prediction',fontsize=15)
  plt.xlabel('NRMSE: '+str(round(pdResults.loc[file]["Prediction v. GT NRMSE"],3))+', PSNR: '+str(round(pdResults.loc[file]["Prediction v. GT PSNR"],3)),fontsize=14)
  plt.savefig(QC_model_path+'/Quality Control/QC_example_data.png', bbox_inches='tight', pad_inches=0)
print('--------------------------------------------')
pdResults.head()

# Export pdf with summary of QC results
qc_pdf_export()

# **6. Using the trained model**

---

<font size = 4>In this section, unseen data is processed using the model trained in section 4. First, your unseen images are uploaded and prepared for prediction. The trained model from section 4 is then loaded and applied, and the results are saved into your Google Drive.

## **6.1 Generate image prediction and localizations from unseen dataset**
---

<font size = 4>The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).

<font size = 4>**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.

<font size = 4>**`Result_folder`:** This folder will contain the CSV file of the localizations found.

<font size = 4>**`batch_size`:** This parameter determines how many frames are processed in a single pass on the GPU. A higher `batch_size` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the `batch_size`. **DEFAULT: 4**

<font size = 4>**`threshold`:** This parameter sets the threshold for local maxima finding. The value is expected to lie in the range **[0,1]**. A higher `threshold` will result in fewer localizations. **DEFAULT: 0.1**

<font size = 4>**`neighborhood_size`:** This parameter sets the size of the neighborhood within which the prediction needs to be a local maximum, in recovery pixels (CCD pixel/upsampling_factor). A high `neighborhood_size` will make the prediction slower and may discard nearby localizations. **DEFAULT: 3**

<font size = 4>**`use_local_average`:** This parameter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True**, inference will be slightly slower, depending on the size of the FOV. **DEFAULT: True**
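
<font size = 4>As a rough illustration of how `threshold` and `neighborhood_size` interact, the sketch below finds local maxima in a small synthetic prediction map using `scipy.ndimage.maximum_filter`. It is not the notebook's own localization routine: the helper `find_local_maxima` and the test image are hypothetical, and the threshold is assumed to be relative to the maximum of the prediction.

```python
import numpy as np
from scipy.ndimage import maximum_filter

def find_local_maxima(prediction, threshold=0.1, neighborhood_size=3):
    """Return (row, col) indices of local maxima above a relative threshold.

    Hypothetical helper for illustration; assumes `threshold` is relative
    to the maximum of the prediction, i.e. in the range [0, 1].
    """
    peak = prediction.max()
    if peak <= 0:
        return np.empty((0, 2), dtype=int)
    normalised = prediction / peak
    # A pixel is a local maximum if it equals the maximum of its neighbourhood
    local_max = maximum_filter(normalised, size=neighborhood_size) == normalised
    mask = local_max & (normalised > threshold)
    return np.argwhere(mask)

# Synthetic prediction with two emitters of very different brightness
demo = np.zeros((16, 16))
demo[4, 5] = 1.0
demo[10, 12] = 0.05

peaks_default = find_local_maxima(demo, threshold=0.1)
peaks_low = find_local_maxima(demo, threshold=0.01)
print(len(peaks_default), len(peaks_low))
```

<font size = 4>In this toy example the dim emitter is only recovered with the lower threshold, which mirrors the trade-off described above: a higher `threshold` rejects weak detections and yields fewer localizations.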


In [None]:

# ------------------------------- User input -------------------------------
#@markdown ### Data parameters
Data_folder = "" #@param {type:"string"}
Result_folder = "" #@param {type:"string"}
#@markdown Get pixel size from file?
get_pixel_size_from_file = True #@param {type:"boolean"}
#@markdown Otherwise, use this value (in nm):
pixel_size = 100 #@param {type:"number"}

#@markdown ### Model parameters
#@markdown Do you want to use the model you just trained?
Use_the_current_trained_model = True #@param {type:"boolean"}
#@markdown Otherwise, please provide path to the model folder below
prediction_model_path = "" #@param {type:"string"}

#@markdown ### Prediction parameters
batch_size =  4#@param {type:"integer"}

#@markdown ### Post processing parameters
threshold =  0.1#@param {type:"number"}
neighborhood_size =  3#@param {type:"integer"}
#@markdown Do you want to locally average the model output with CoG estimator ?
use_local_average = True #@param {type:"boolean"}


if get_pixel_size_from_file:
  pixel_size = None

if (Use_the_current_trained_model):
  prediction_model_path = os.path.join(model_path, model_name)

if os.path.exists(prediction_model_path):
  print("The "+os.path.basename(prediction_model_path)+" model will be used.")
else:
  print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)
  print('Please make sure you provide a valid model path before proceeding further.')

# inform user whether local averaging is being used
if use_local_average:
  print('Using local averaging')

if not os.path.exists(Result_folder):
  print('Result folder was created.')
  os.makedirs(Result_folder)


# ------------------------------- Run predictions -------------------------------

start = time.time()
#%% This script tests the trained fully convolutional network based on the
# saved training weights, and normalization created using train_model.

if os.path.isdir(Data_folder):
  for filename in list_files(Data_folder, 'tif'):
    # run the testing/reconstruction process
    print("------------------------------------")
    print("Running prediction on: "+ filename)
    batchFramePredictionLocalization(Data_folder, filename, prediction_model_path, Result_folder,
                                     batch_size,
                                     threshold,
                                     neighborhood_size,
                                     use_local_average,
                                     pixel_size = pixel_size)

elif os.path.isfile(Data_folder):
  batchFramePredictionLocalization(os.path.dirname(Data_folder), os.path.basename(Data_folder), prediction_model_path, Result_folder,
                                   batch_size,
                                   threshold,
                                   neighborhood_size,
                                   use_local_average,
                                   pixel_size = pixel_size)



print('--------------------------------------------------------------------')
# Displaying the time elapsed for prediction
dt = time.time() - start
minutes, seconds = divmod(dt, 60)
hours, minutes = divmod(minutes, 60)
print("Time elapsed:",hours, "hour(s)",minutes,"min(s)",round(seconds),"sec(s)")


# ------------------------------- Interactive display -------------------------------

print('--------------------------------------------------------------------')
print('---------------------------- Previews ------------------------------')
print('--------------------------------------------------------------------')

if os.path.isdir(Data_folder):
  @interact
  def show_QC_results(file = list_files(Data_folder, 'tif')):

    plt.figure(figsize=(15,7.5))
    # Wide-field
    plt.subplot(1,2,1)
    plt.axis('off')
    img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))
    plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))
    plt.title('Widefield', fontsize=15)
    # Prediction
    plt.subplot(1,2,2)
    plt.axis('off')
    img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))
    plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))
    plt.title('Predicted',fontsize=15)

if os.path.isfile(Data_folder):

  plt.figure(figsize=(15,7.5))
  # Wide-field
  plt.subplot(1,2,1)
  plt.axis('off')
  img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+os.path.basename(Data_folder)))
  plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))
  plt.title('Widefield', fontsize=15)
  # Prediction
  plt.subplot(1,2,2)
  plt.axis('off')
  img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+os.path.basename(Data_folder)))
  plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))
  plt.title('Predicted',fontsize=15)



## **6.2 Drift correction**
---

<font size = 4>The visualization above is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. The display is a preview without any drift correction applied. This section performs drift correction using cross-correlation between time bins to estimate the drift.

<font size = 4>**`Loc_file_path`:** is the path to the localization file to use for visualization.

<font size = 4>**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.

<font size = 4>**`visualization_pixel_size`:** This parameter sets the pixel size (in **nm**) of the image reconstructions used for the drift correction estimation. A smaller pixel size will be more precise but will take longer to compute. **DEFAULT: 20**

<font size = 4>**`number_of_bins`:** This parameter defines how many temporal bins are used across the full dataset. All localizations in each bin are used to build an image. This image is used to find the drift with respect to the image obtained from the very first bin. A typical value would correspond to about 500 frames per bin. **DEFAULT: Total number of frames / 500**

<font size = 4>**`polynomial_fit_degree`:** The drift obtained for each temporal bin needs to be interpolated to every single frame. This is performed by a polynomial fit, the degree of which is defined here. **DEFAULT: 4**

<font size = 4> The drift-corrected localization data is automatically saved in the `save_path` folder.
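
<font size = 4>The idea behind the cross-correlation approach can be sketched on synthetic data as follows. This is a simplified illustration under stated assumptions, not the notebook's implementation: the localizations, the `render` helper and the drift values are hypothetical, and the peak is located with an integer `argmax` rather than the sub-pixel estimator used in the cell below.

```python
import numpy as np
from scipy.signal import fftconvolve

rng = np.random.default_rng(0)
size = 64

# Reference localizations (in pixels), kept away from the image border
points = rng.integers(10, 50, size=(40, 2))

def render(points, shift=(0, 0)):
    """Render localizations as a simple histogram image, optionally shifted."""
    image = np.zeros((size, size))
    for row, col in points:
        image[row + shift[0], col + shift[1]] += 1
    return image

# Known drift (rows, cols) applied to each temporal bin; hypothetical values
true_drift = [(0, 0), (1, 0), (2, 1), (3, 1), (4, 2)]
reference = render(points)

peaks = []
for shift in true_drift:
    block = render(points, shift)
    # Convolving with the 180-degree rotated reference is a cross-correlation
    xc = fftconvolve(np.rot90(reference, k=2), block, mode='same')
    peaks.append(np.unravel_index(np.argmax(xc), xc.shape))

# Drift is measured relative to the first bin
peaks = np.array(peaks)
estimated_drift = peaks - peaks[0]

# Interpolate the per-bin drift to every frame with a polynomial fit
frames_per_bin = 100
bin_centres = np.arange(len(true_drift)) * frames_per_bin + frames_per_bin / 2
row_fit = np.poly1d(np.polyfit(bin_centres, estimated_drift[:, 0], 2))
row_drift_per_frame = row_fit(np.arange(len(true_drift) * frames_per_bin))
print(estimated_drift)
```

<font size = 4>Each bin's cross-correlation peak moves by the same amount as the localizations have drifted, so subtracting the peak position of the first bin gives the drift per bin. The polynomial fit then turns these few per-bin estimates into a smooth per-frame correction.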

In [None]:
# @markdown ##Data parameters
Loc_file_path = "" #@param {type:"string"}
# @markdown Provide information about original data. Get the info automatically from the raw data?
Get_info_from_file = True #@param {type:"boolean"}
# Loc_file_path = "/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv" #@param {type:"string"}
original_image_path = "" #@param {type:"string"}
# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)
image_width =  256#@param {type:"integer"}
image_height =  256#@param {type:"integer"}
pixel_size = 100 #@param {type:"number"}

# @markdown ##Drift correction parameters
visualization_pixel_size =  20#@param {type:"number"}
number_of_bins =  50#@param {type:"integer"}
polynomial_fit_degree =  4#@param {type:"integer"}

# @markdown ##Saving parameters
save_path = '' #@param {type:"string"}


# Let's go !
start = time.time()

# Get info from the raw file if selected
if Get_info_from_file:
  pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)

# Read the localizations in
LocData = pd.read_csv(Loc_file_path)

# Calculate a few variables
Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))
Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))
nFrames = max(LocData['frame'])
x_max = max(LocData['x [nm]'])
y_max = max(LocData['y [nm]'])
image_size = (Mhr, Nhr)
n_locs = len(LocData.index)

print('Image size: '+str(image_size))
print('Number of frames in data: '+str(nFrames))
print('Number of localizations in data: '+str(n_locs))

blocksize = math.ceil(nFrames/number_of_bins)
print('Number of frames per block: '+str(blocksize))

blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()
xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)
yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)

# Preparing the Reference image
photon_array = np.ones(yc_array.shape[0])
sigma_array = np.ones(yc_array.shape[0])
ImageRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)
ImagesRef = np.rot90(ImageRef, k=2)

xDrift = np.zeros(number_of_bins)
yDrift = np.zeros(number_of_bins)

filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]

with open(os.path.join(save_path, filename_no_extension+"_DriftCorrectionData.csv"), "w", newline='') as file:
  writer = csv.writer(file)

  # Write the header in the csv file
  writer.writerow(["Block #", "x-drift [nm]","y-drift [nm]"])

  for b in tqdm(range(number_of_bins)):

    blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()
    xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)
    yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)

    photon_array = np.ones(yc_array.shape[0])
    sigma_array = np.ones(yc_array.shape[0])
    ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)

    XC = fftconvolve(ImagesRef, ImageBlock, mode = 'same')
    yDrift[b], xDrift[b] = subPixelMaxLocalization(XC, method = 'CoM')

    # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, visualization_pixel_size)
    # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)
    writer.writerow([str(b), str((xDrift[b]-xDrift[0])*visualization_pixel_size), str((yDrift[b]-yDrift[0])*visualization_pixel_size)])


print('--------------------------------------------------------------------')
# Displaying the time elapsed for drift estimation
dt = time.time() - start
minutes, seconds = divmod(dt, 60)
hours, minutes = divmod(minutes, 60)
print("Time elapsed:",hours, "hour(s)",minutes,"min(s)",round(seconds),"sec(s)")

print('Fitting drift data...')
bin_number = np.arange(number_of_bins)*blocksize + blocksize/2
xDrift = (xDrift-xDrift[0])*visualization_pixel_size
yDrift = (yDrift-yDrift[0])*visualization_pixel_size

xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)
yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)

xDriftFit = np.poly1d(xDriftCoeff)
yDriftFit = np.poly1d(yDriftCoeff)
bins = np.arange(nFrames)
xDriftInterpolated = xDriftFit(bins)
yDriftInterpolated = yDriftFit(bins)


# ------------------ Displaying the image results ------------------

plt.figure(figsize=(15,10))
plt.plot(bin_number,xDrift, 'r+', label='x-drift')
plt.plot(bin_number,yDrift, 'b+', label='y-drift')
plt.plot(bins,xDriftInterpolated, 'r-', label='x-drift (fit)')
plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')
plt.title('Cross-correlation estimated drift')
plt.ylabel('Drift [nm]')
plt.xlabel('Frame number')
plt.legend();

dt = time.time() - start
minutes, seconds = divmod(dt, 60)
hours, minutes = divmod(minutes, 60)
print("Time elapsed:", hours, "hour(s)",minutes,"min(s)",round(seconds),"sec(s)")


# ------------------ Actual drift correction -------------------

print('Correcting localization data...')
xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)
yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)
frames = LocData['frame'].to_numpy(dtype=np.int32)


xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)
ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)
ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size = image_size, pixel_size = visualization_pixel_size)


# ------------------ Displaying the image results ------------------
plt.figure(figsize=(15,7.5))
# Raw
plt.subplot(1,2,1)
plt.axis('off')
plt.imshow(ImageRaw, norm = simple_norm(ImageRaw, percent = 99.5))
plt.title('Raw', fontsize=15);
# Corrected
plt.subplot(1,2,2)
plt.axis('off')
plt.imshow(ImageCorr, norm = simple_norm(ImageCorr, percent = 99.5))
plt.title('Corrected',fontsize=15);


# ------------------ Table with info -------------------
driftCorrectedLocData = pd.DataFrame()
driftCorrectedLocData['frame'] = frames
driftCorrectedLocData['x [nm]'] = xc_array_Corr
driftCorrectedLocData['y [nm]'] = yc_array_Corr
driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']

driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'))
print('-------------------------------')
print('Corrected localizations saved.')


## **6.3 Visualization of the localizations**
---


<font size = 4>The visualization in section 6.1 is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. This section performs visualization of the result by plotting the localizations as a simple histogram.

<font size = 4>**`Loc_file_path`:** is the path to the localization file to use for visualization.

<font size = 4>**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.

<font size = 4>**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the final image reconstruction (in **nm**). **DEFAULT: 10**

<font size = 4>**`visualization_mode`:** This parameter defines which method is used to render the final image. Note that the Integrated Gaussian mode can be quite slow. **DEFAULT: Simple histogram.**
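
<font size = 4>A simple histogram rendering amounts to counting the localizations that fall into each pixel of the output grid. The sketch below shows this with `numpy.histogram2d`. The helper `render_histogram` is a hypothetical stand-in, not the notebook's `FromLoc2Image_SimpleHistogram`, and the coordinates are made up for illustration.

```python
import math
import numpy as np

def render_histogram(x_nm, y_nm, image_size, pixel_size):
    """Bin localizations (in nm) into an image of shape `image_size` (rows, cols).

    Hypothetical stand-in for illustration, not the notebook's
    `FromLoc2Image_SimpleHistogram`.
    """
    n_rows, n_cols = image_size
    image, _, _ = np.histogram2d(
        y_nm, x_nm,
        bins=(n_rows, n_cols),
        range=((0, n_rows * pixel_size), (0, n_cols * pixel_size)),
    )
    return image

# A 4 x 4 pixel camera frame with 100 nm pixels, rendered at 50 nm per pixel
image_height, image_width, pixel_size = 4, 4, 100
visualization_pixel_size = 50
image_size = (
    int(math.ceil(image_height * pixel_size / visualization_pixel_size)),
    int(math.ceil(image_width * pixel_size / visualization_pixel_size)),
)

x_nm = np.array([10.0, 20.0, 260.0])
y_nm = np.array([10.0, 30.0, 120.0])
loc_image = render_histogram(x_nm, y_nm, image_size, visualization_pixel_size)
print(loc_image.shape, loc_image.sum())
```

<font size = 4>Halving `visualization_pixel_size` doubles the number of pixels along each axis, so the same localizations are spread over a finer grid and each pixel collects fewer counts.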





In [None]:
# @markdown ##Data parameters
Use_current_drift_corrected_localizations = True #@param {type:"boolean"}
# @markdown Otherwise provide a localization file path
Loc_file_path = "" #@param {type:"string"}
# @markdown Provide information about original data. Get the info automatically from the raw data?
Get_info_from_file = True #@param {type:"boolean"}
# Loc_file_path = "/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv" #@param {type:"string"}
original_image_path = "" #@param {type:"string"}
# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)
image_width =  256#@param {type:"integer"}
image_height =  256#@param {type:"integer"}
pixel_size =  100#@param {type:"number"}

# @markdown ##Visualization parameters
visualization_pixel_size =  10#@param {type:"number"}
visualization_mode = "Simple histogram" #@param ["Simple histogram", "Integrated Gaussian (SLOW!)"]

# Start the timer for this cell
start = time.time()

if not Use_current_drift_corrected_localizations:
  filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]


if Get_info_from_file:
  pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)

if Use_current_drift_corrected_localizations:
  LocData = driftCorrectedLocData
else:
  LocData = pd.read_csv(Loc_file_path)

Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))
Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))


nFrames = max(LocData['frame'])
x_max = max(LocData['x [nm]'])
y_max = max(LocData['y [nm]'])
image_size = (Mhr, Nhr)

print('Image size: '+str(image_size))
print('Number of frames in data: '+str(nFrames))
print('Number of localizations in data: '+str(len(LocData.index)))

xc_array = LocData['x [nm]'].to_numpy()
yc_array = LocData['y [nm]'].to_numpy()
if (visualization_mode == 'Simple histogram'):
  locImage = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)
elif (visualization_mode == 'Shifted histogram'):
  print(bcolors.WARNING+'Method not implemented yet!'+bcolors.NORMAL)
  locImage = np.zeros(image_size)
elif (visualization_mode == 'Integrated Gaussian (SLOW!)'):
  photon_array = np.ones(xc_array.shape)
  sigma_array = np.ones(xc_array.shape)
  locImage = FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = image_size, pixel_size = visualization_pixel_size)

print('--------------------------------------------------------------------')
# Displaying the time elapsed for visualization
dt = time.time() - start
minutes, seconds = divmod(dt, 60)
hours, minutes = divmod(minutes, 60)
print("Time elapsed:",hours, "hour(s)",minutes,"min(s)",round(seconds),"sec(s)")

# Display
plt.figure(figsize=(20,10))
plt.axis('off')
# plt.imshow(locImage, cmap='gray');
plt.imshow(locImage, norm = simple_norm(locImage, percent = 99.5));


LocData.head()



In [None]:
# @markdown ---

# @markdown #Play this cell to save the visualization

# @markdown ####Please select a path to the folder where to save the visualization.
save_path = "" #@param {type:"string"}

if not os.path.exists(save_path):
  os.makedirs(save_path)
  print('Folder created.')

saveAsTIF(save_path, filename_no_extension+'_Visualization', locImage, visualization_pixel_size)
print('Image saved.')

## **6.4 Download your predictions**
---

<font size = 4>**Store your data** and ALL its results elsewhere by downloading them from Google Drive, then clean the original folder tree (datasets, results, trained model, etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files that have the same name.

# **7. Version log**
---
<font size = 4>**v1.13.3**:  

*    Convert input variables into numerical data.
*    Check variables in globals rather than locals.

<font size = 4>**v1.13.2**:  

*    Replaced all absolute pathing with relative pathing

<font size = 4>**v1.13**:
* Sections 1 and 2 are now swapped for better export of *requirements.txt*.
* This version also now includes built-in version check and the version log that you're reading now.

---


# **Thank you for using Deep-STORM 2D!**