In [None]:
import os
import shutil
from PIL import Image
from tifffile import imread

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from csbdeep.utils import download_and_extract_zip_file, plot_some
from csbdeep.data import RawData, create_patches
from csbdeep.data import no_background_patches, norm_percentiles, sample_percentiles

In [None]:
from tifffile import imread
from csbdeep.utils import axes_dict, plot_some, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.io import load_training_data
from csbdeep.models import Config, CARE

# Training data: newly selected low/high-resolution image pairs

In [None]:
# Pair each low-resolution image in Gauss/train_lr_tif_newselect with the
# image of the same file name in Gauss/train_hr_tif_newselect.
# axes='YX' declares plain 2D images without a channel axis.
raw_data = RawData.from_folder(
    basepath='Gauss',
    source_dirs=['train_lr_tif_newselect'],
    target_dir='train_hr_tif_newselect',
    axes='YX',
)

In [None]:
# Cut 2 patches of 64 x 100 pixels (Y x X) from every image pair and cache
# them in an .npz file; the training section reloads this same file.
# NOTE(review): no_background_patches(0) uses a threshold of 0, which
# presumably lets almost every patch through -- confirm this is intended.
X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=(64, 100),
    patch_filter=no_background_patches(0),
    n_patches_per_image=2,
    save_file='Gauss/liver_newselect_train_norm.npz',
)

In [None]:
# Source (X) and target (Y) patches must line up one-to-one.
assert X.shape == Y.shape
for label, value in (("shape of X,Y =", X.shape), ("axes  of X,Y =", XY_axes)):
    print(label, value)

In [None]:
# Preview the first (up to) 16 source/target patch pairs, 8 per figure,
# showing index 0 along axis 1 (the channel axis reported in XY_axes).
# Slices are clamped to the number of patches so a small dataset neither
# plots empty arrays nor gets more titles than images.
n_per_fig = 8
n_preview = min(2 * n_per_fig, len(X))
for start in range(0, n_preview, n_per_fig):
    stop = min(start + n_per_fig, n_preview)
    sl = slice(start, stop), 0
    plt.figure(figsize=(16,4))
    plot_some(X[sl], Y[sl], title_list=[np.arange(start, stop)])
    plt.show()
None;

## Training

In [None]:
# Reload the cached patches and hold out 20% of them for validation.
(X, Y), (X_val, Y_val), axes = load_training_data(
    'Gauss/liver_newselect_train_norm.npz',
    validation_split=0.2,
    verbose=True,
)

# Position of the channel axis; used to read the input/output channel counts.
c = axes_dict(axes)['C']
n_channel_in = X.shape[c]
n_channel_out = Y.shape[c]

In [None]:
# First five validation pairs: sources on the top row, targets below.
first_five = slice(0, 5)
plt.figure(figsize=(12,5))
plot_some(X_val[first_five], Y_val[first_five])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)');

In [None]:
# Training schedule: batches of 8, 30 epochs of 400 steps each
# (the model name below encodes the same schedule).
config = Config(
    axes,
    n_channel_in,
    n_channel_out,
    train_batch_size=8,
    train_epochs=30,
    train_steps_per_epoch=400,
)
print(config)
vars(config)

In [None]:
# Build a fresh CARE network under models/400steps_30epochs_newselect
# and print its layer summary.
model = CARE(config, name='400steps_30epochs_newselect', basedir='models')
model.keras_model.summary()

In [None]:
# Fit on the training split while tracking the held-out validation split.
history = model.train(
    X, Y,
    validation_data=(X_val, Y_val),
)

In [None]:
# Export the trained network as a TensorFlow SavedModel (e.g. for the CSBDeep
# Fiji plugin); presumably written into the model's folder under models/ -- confirm.
model.export_TF()

In [None]:
# List every recorded metric, then plot losses next to the error metrics.
print(sorted(history.history))
plt.figure(figsize=(16,5))
plot_history(history, ['loss', 'val_loss'], ['mse', 'val_mse', 'mae', 'val_mae']);

## Predict

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from tifffile import imsave
from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import save_tiff_imagej_compatible
from csbdeep.models import CARE

from tqdm import tqdm

In [None]:
# Peek at one low-resolution test file name. os.listdir returns entries in
# arbitrary order, so sort first to show the same file on every machine.
test_lr_dir = 'Gauss/test_lr_tif/'
sorted(os.listdir(test_lr_dir))[10]

In [None]:
# Reload the trained network; passing None as the config makes CARE read the
# configuration saved in models/400steps_30epochs_newselect.
model = CARE(None, '400steps_30epochs_newselect', basedir='models')

In [None]:
from os import listdir

def list_files(directory, extension):
    """Return the names of the entries in `directory` ending in `extension`.

    Parameters
    ----------
    directory : str
        Folder to scan (not recursive).
    extension : str
        File extension with or without the leading dot, e.g. "tif" or ".tif".
        Matching is case-sensitive.

    Returns
    -------
    list of str
        Matching bare names (not full paths), sorted so the order is
        reproducible -- os.listdir makes no ordering guarantee.
    """
    # Accept both "tif" and ".tif"; previously ".tif" became "..tif" and
    # silently matched nothing.
    suffix = '.' + extension.lstrip('.')
    return sorted(f for f in listdir(directory) if f.endswith(suffix))

# Low-resolution test images to restore. Sorting makes both the prediction
# order and `filenames[0]` reproducible across machines.
directory = "Gauss/test_lr_tif/"
filenames = sorted(list_files(directory, "tif"))
# Fail here rather than silently predicting nothing further down.
assert filenames, f"no .tif files found in {directory}"

print(len(filenames))

In [None]:
# Display the first test file name as a quick sanity check of the listing.
filenames[0]

In [None]:
%%time

output_dir =  "Gauss/restored_liver_400steps_30epochs_newselect/"
for image in filenames:
    x = imread(directory+image)
    restored = model.predict(x, axes='YX')
    imsave(output_dir+image, restored)

In [None]:
# Side-by-side view of one test image: input, ground truth and CARE output.
y = imread('Gauss/test_hr_tif/liver1peak759.tif')
x = imread('Gauss/test_lr_tif/liver1peak759.tif')
check = imread('Gauss/restored_liver_400steps_30epochs_newselect/liver1peak759.tif')

# Use a dedicated name: `axes` holds the training-data axes returned by
# load_training_data and is passed to Config, so overwriting it here would
# break re-running the training cells.
img_axes = 'YX'
print('image size =', x.shape)
print('image axes =', img_axes)

fig, panels = plt.subplots(1, 3, figsize=(20,5))
for ax, img, title in zip(panels, (x, y, check), ("low", "high", "restored")):
    im = ax.imshow(img, cmap="magma")
    if title == "restored":
        # NOTE(review): only the restored panel is clipped to [0, 1]; this
        # assumes restored intensities live in that range -- confirm.
        im.set_clim(0, 1)
    fig.colorbar(im, ax=ax)
    ax.set_title(title)

In [None]:
from csbdeep.utils import normalize

# Intensity profile along one image row. Each image is percentile-normalized
# (1st..99.7th) first so the three curves are on a comparable scale.
PROFILE_ROW = 45

fig, ax = plt.subplots(figsize=(10,5))
for _x, _name in zip((x, check, y), ('low', 'CARE', 'GT')):
    ax.plot(normalize(_x, 1, 99.7)[PROFILE_ROW], label=_name, lw=2)
ax.set(title=f'Normalized intensity along row {PROFILE_ROW}',
       xlabel='column (pixel)', ylabel='normalized intensity')
ax.legend();

In [None]:
# A few ground-truth file names; sorted because os.listdir order is arbitrary.
sorted(os.listdir('Gauss/test_hr_tif'))[:4]

In [None]:
# Load ground truth, input and CARE output (in that order) for one test image.
y, x, check = (imread(path) for path in (
    'Gauss/test_hr_tif/liver1peak108.tif',
    'Gauss/test_lr_tif/liver1peak108.tif',
    'Gauss/restored_liver_400steps_30epochs_newselect/liver1peak108.tif',
))

# Large panel of input, CARE output and ground truth, displayed with a
# 2..99.8 percentile intensity range.
plt.figure(figsize=(25,20))
plot_some(np.stack([x, check, y]), title_list=[['low', 'CARE', 'GT']], pmin=2, pmax=99.8)

# Row-45 intensity profile of each image after 1..99.7 percentile normalization.
plt.figure(figsize=(10,5))
for img, label in zip((x, check, y), ('low', 'CARE', 'GT')):
    plt.plot(normalize(img, 1, 99.7)[45], label=label, lw=2)
plt.legend();

In [None]:
# Load ground truth, input and CARE output (in that order) for one test image.
y, x, check = (imread(path) for path in (
    'Gauss/test_hr_tif/liver1peak286.tif',
    'Gauss/test_lr_tif/liver1peak286.tif',
    'Gauss/restored_liver_400steps_30epochs_newselect/liver1peak286.tif',
))

# Large panel of input, CARE output and ground truth, displayed with a
# 2..99.8 percentile intensity range.
plt.figure(figsize=(25,20))
plot_some(np.stack([x, check, y]), title_list=[['low', 'CARE', 'GT']], pmin=2, pmax=99.8)

# Row-45 intensity profile of each image after 1..99.7 percentile normalization.
plt.figure(figsize=(10,5))
for img, label in zip((x, check, y), ('low', 'CARE', 'GT')):
    plt.plot(normalize(img, 1, 99.7)[45], label=label, lw=2)
plt.legend();