<b><big>Step 3 - Restore</big></b>

Make sure that you have run CombinedP1P2Stitch.ijm in ImageJ and have the stitched files ready.

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
import tkinter as tk
import math
import os
%gui tk
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from tkinter import filedialog
from csbdeep.utils import Path, plot_some
from csbdeep.io import save_tiff_imagej_compatible
from csbdeep.models import CARE
from IPython.display import clear_output
from tqdm import tqdm

def next_power_of_2(x):
    """Return the smallest power of 2 that is >= x (1 for x == 0).

    Args:
        x: A non-negative integer. Python ints and other ``numbers.Integral``
            types (e.g. NumPy integers) are accepted; ``bool`` is rejected.

    Returns:
        The smallest power of two greater than or equal to ``x``, as an int.

    Raises:
        TypeError: If ``x`` is not a non-negative integer.
    """
    import numbers

    # bool is an Integral subclass, but a flag is never a meaningful size.
    if isinstance(x, bool) or not isinstance(x, numbers.Integral) or x < 0:
        raise TypeError("Input must be a non-negative int.")
    x = int(x)  # normalise e.g. NumPy integers so .bit_length() is available
    return 1 if x == 0 else 2 ** (x - 1).bit_length()

def subdivide(shape=None, ref=10 * 512 * 512):
    """Return the number of tiles per (Z, Y, X) axis to reduce GPU memory use.

    The tile count scales with sqrt(voxel_count / ref): Y is split into a
    power-of-two number of tiles, X into half as many (never fewer than 1),
    and Z is never split.

    Args:
        shape: Image dimensions to base the tiling on. When omitted, the
            module-level ``size`` is used (the processing loop sets it before
            every call), which matches the original zero-argument behaviour.
        ref: Reference voxel count; images with fewer voxels than this get a
            single tile along every axis.

    Returns:
        A ``(z_tiles, y_tiles, x_tiles)`` tuple of positive ints, passed as
        ``n_tiles`` to ``CARE.predict`` on 'ZYX' data.
    """
    if shape is None:
        shape = size  # module-level global set by the processing loop
    product = 1
    for dim in shape:
        product *= dim
    target = math.sqrt(product / ref)
    zslice = 1
    yslice = next_power_of_2(math.floor(target) + 1)
    # For images smaller than `ref`, yslice is 1; halving it must not give
    # 0 tiles, which is not a valid n_tiles entry.
    xslice = max(1, yslice // 2)
    return (zslice, yslice, xslice)

# Hidden Tk root: the file dialogs need one, but no empty main window should
# appear; '-topmost' keeps the dialogs in front of the notebook.
root = tk.Tk()
root.withdraw()
root.call('wm', 'attributes', '.', '-topmost', True)

chan = int(input("Number of channels: "))
if chan < 1:
    raise ValueError("Number of channels must be at least 1.")

# Normalise the answer so inputs like "Y" or " n " are accepted.
maxproj = input("Max Project? y for Yes, n for No.").strip().lower()
if maxproj not in ("y", "n"):
    raise ValueError("Only type y or n.")

directory = filedialog.askdirectory(title="Choose folder directory")
if not directory:
    # askdirectory returns an empty result when the dialog is cancelled.
    raise ValueError("No folder selected.")

# Create list of csbdeep models: one CARE model per channel, each model
# folder chosen interactively by the user.
models = list()
for i in range(1, chan + 1):
    modeldir = filedialog.askdirectory(title="Choose model for channel " + str(i),
                                       initialdir="models")
    if not modeldir:
        # askdirectory returns an empty result when the dialog is cancelled.
        raise ValueError("No model folder selected for channel " + str(i) + ".")
    # CARE(config=None, ...) loads the model found at basedir/name, so pass the
    # chosen folder's own name and its parent directory.
    models.append(CARE(config=None, name=os.path.basename(modeldir),
                       basedir=os.path.join(modeldir, os.pardir)))

# Collect the .tif inputs once, so the progress-bar total equals the number
# of loop iterations.
tif_files = sorted(file for file in os.listdir(directory) if file.endswith(".tif"))
filecounter = len(tif_files)

# Restored images are written to <directory>/Restored.
restored_dir = os.path.join(directory, "Restored")
Path(restored_dir).mkdir(exist_ok=True)
# Axis order of the saved stack: single-channel inputs have no C axis.
options = 'ZYX' if chan == 1 else 'ZCYX'

# Iterate through all .tif images in the folder.
for file in tqdm(tif_files, total=filecounter, unit="files"):
    current = imread(os.path.join(directory, file))
    # subdivide() reads the module-level `size` to choose n_tiles.
    size = current.shape
    print(file)
    if chan == 1:  # 1 channel: ZYX stack, no C axis
        result = models[0].predict(current, 'ZYX', n_tiles=subdivide())
    else:  # >1 channels: ZCYX, restore each channel as its own ZYX stack with its own model
        if current.ndim != 4 or current.shape[1] < chan:
            raise ValueError("%s: expected a ZCYX stack with at least %d channels, got shape %s."
                             % (file, chan, current.shape))
        # float32 halves memory versus NumPy's float64 default; ImageJ-compatible
        # saving stores floats as 32-bit anyway.
        # NOTE(review): assumes CARE.predict returns float32 - confirm.
        result = np.zeros(size, dtype=np.float32)
        size = (size[0],) + size[2:]  # ZCYX -> ZYX for the tiling computation
        for c in range(chan):
            print('Channel ' + str(c + 1))
            result[:, c, :, :] = models[c].predict(current[:, c, :, :], 'ZYX', n_tiles=subdivide())
    if maxproj == "y":
        # Max-project along Z, keeping a length-1 Z axis so `options` still applies.
        result = np.amax(result, axis=0, keepdims=True)
    save_tiff_imagej_compatible(os.path.join(restored_dir, file), result, options)
    clear_output(wait=True)
print("Done!")