In [None]:
#@markdown ##Run this cell to check if you have GPU access
# %tensorflow_version 1.x

import tensorflow as tf
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime ?') 
  print('If the runtime setting is correct then Google did not allocate a GPU for your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')
  !nvidia-smi

In [None]:
Notebook_version = ['1.12']
import tensorflow
# ------- Variable specific to N2V -------
from n2v.models import N2VConfig, N2V
from csbdeep.utils import plot_history
from n2v.utils.n2v_utils import manipulate_val_data
from n2v.internals.N2V_DataGenerator import N2V_DataGenerator
from csbdeep.io import save_tiff_imagej_compatible

# ------- Common variable to all ZeroCostDL4Mic notebooks -------
import numpy as np
from matplotlib import pyplot as plt
import urllib
import os, random
import shutil 
import zipfile
from tifffile import imread, imsave
import time
import sys
import wget
from pathlib import Path
import pandas as pd
import csv
from glob import glob
from scipy import signal
from scipy import ndimage
from skimage import io
from sklearn.linear_model import LinearRegression
from skimage.util import img_as_uint
import matplotlib as mpl
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as psnr
from astropy.visualization import simple_norm
from skimage import img_as_float32
from fpdf import FPDF, HTMLMixin
from datetime import datetime
from pip._internal.operations.freeze import freeze
import subprocess
from datetime import datetime

# Colors for the warning messages
class bcolors:
  # ANSI escape code that switches the terminal text colour to red
  WARNING = '\033[31m'
W  = '\033[0m'  # white (normal) -- append after a coloured message to reset
R  = '\033[31m' # red

#Disable some of the tensorflow warnings
import warnings
warnings.filterwarnings("ignore")

print("Libraries installed")


# Check if this is the latest version of the notebook.
# The release CSV is compared by its column labels: this assumes its header row
# holds the latest "major.minor" version string (TODO confirm if the CSV format changes).
# A failed download only prints a warning so the rest of the notebook can still run.
print('Notebook version: '+Notebook_version[0])
# Keep only "major.minor" (e.g. '1.12.3' -> '1.12')
Notebook_version_main = '.'.join(Notebook_version[0].split('.')[:2])
try:
  Latest_notebook_version = pd.read_csv("https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv")
except (OSError, ValueError) as err:
  # URLError/HTTPError are OSError subclasses; pandas parse errors are ValueError subclasses
  print(bcolors.WARNING + "Could not check for the latest notebook version: " + str(err) + W)
else:
  # Membership test on the Index, instead of an element-wise `==`, which becomes
  # ambiguous inside an `if` as soon as the CSV has more than one column.
  if Notebook_version_main in Latest_notebook_version.columns:
    print("This notebook is up-to-date.")
  else:
    print(bcolors.WARNING +"A new version of this notebook has been released. We recommend that you download it at https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki" + W)


In [None]:
# create DataGenerator-object.
# The N2V data generator loads the training images and later extracts patches from them.
datagen = N2V_DataGenerator()

#@markdown ###Path to training image(s): 
Training_source = "Training" #@param {type:"string"}

#compatibility to easily change the name of the parameters
training_images = Training_source 
imgs = datagen.load_imgs_from_directory(directory = Training_source)

#@markdown ### Model name and path:
model_name = "model" #@param {type:"string"}

model_path = "." #@param {type:"string"}

# "<model_path>/<model_name>/" -- presumably the model output folder used by later cells (verify)
full_model_path = model_path+'/'+model_name+'/'

#@markdown ###Training Parameters
#@markdown Number of epochs:
number_of_epochs =  100#@param {type:"number"}

#@markdown Patch size (pixels)
# The next cell clamps this to the image size and rounds it down to a multiple of 8.
patch_size =  64#@param {type:"number"}

#@markdown ###Advanced Parameters

Use_Default_Advanced_Parameters = True#@param {type:"boolean"}

#@markdown ###If not, please input:
# When Use_Default_Advanced_Parameters is True, the next cell overrides
# batch_size, percentage_validation and initial_learning_rate with defaults.
batch_size =  128#@param {type:"number"}
number_of_steps = 100#@param {type:"number"}
percentage_validation =  10#@param {type:"number"}
initial_learning_rate = 0.0004 #@param {type:"number"}




In [None]:

# Override the advanced parameters typed in the form with the defaults, if requested.
if Use_Default_Advanced_Parameters:
  print("Default advanced parameters enabled")
  # number_of_steps is defined in the following cell in this case
  batch_size = 128
  percentage_validation = 10
  initial_learning_rate = 0.0004


#here we check that no model with the same name already exist, if so print a warning

if os.path.exists(model_path+'/'+model_name):
  print(bcolors.WARNING +"!! WARNING: "+model_name+" already exists and will be deleted in the following cell !!")
  print(bcolors.WARNING +"To continue training "+model_name+", choose a new model_name here, and load "+model_name+" in section 3.3"+W)


# This will open a randomly chosen dataset input image.
# Hidden entries (e.g. .DS_Store, .ipynb_checkpoints) and sub-folders are skipped
# because imread cannot open them.
candidate_files = [f for f in os.listdir(Training_source)
                   if not f.startswith('.') and os.path.isfile(os.path.join(Training_source, f))]
if not candidate_files:
  raise FileNotFoundError("No image files found in Training_source: "+Training_source)
random_choice = random.choice(candidate_files)
x = imread(Training_source+"/"+random_choice)

# Here we check that the input images contains the expected dimensions (2D: y,x)
if len(x.shape) == 2:
  print("Image dimensions (y,x)",x.shape)
else:
  print(bcolors.WARNING +"Your images appear to have the wrong dimensions. Image dimension",x.shape, W)


#Find image XY dimension
Image_Y = x.shape[0]
Image_X = x.shape[1]

#Hyperparameters failsafes

# Here we check that patch_size is smaller than the smallest xy dimension of the image 
if patch_size > min(Image_Y, Image_X):
  patch_size = min(Image_Y, Image_X)
  print(bcolors.WARNING + " Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:",patch_size, W)

# Here we check that patch_size is divisible by 8, and round it DOWN to the
# nearest multiple of 8 if not. (The previous formula subtracted an extra 8,
# e.g. 63 -> 48, 10 -> 0, 7 -> -8.) Rounding down keeps it within the image.
if patch_size % 8 != 0:
  patch_size = int(patch_size // 8) * 8
  if patch_size == 0:
    raise ValueError("patch_size must be at least 8 pixels: check the patch_size parameter and the size of your training images.")
  print(bcolors.WARNING + " Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:",patch_size, W)

# Here we disable pre-trained model by default (in case the next cell is not run)
Use_pretrained_model = False

# Here we enable data augmentation by default (in case the cell is not run)
Use_Data_augmentation = True

print("Parameters initiated.")

#Here we display one image, normalised to its 99th percentile range
norm = simple_norm(x, percent = 99)

f=plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')
plt.title('Training source')
plt.axis('off');
plt.savefig('TrainingDataExample_N2V2D.png',bbox_inches='tight',pad_inches=0)