In [None]:
!pip install csbdeep

In [None]:
# Report which GPU TensorFlow will use; an empty string means none was found.
import tensorflow as tf

gpu_device = tf.test.gpu_device_name()
gpu_device

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import download_and_extract_zip_file, plot_some, axes_dict
from csbdeep.io import save_training_data
from csbdeep.data import RawData, create_patches
from csbdeep.data.transform import anisotropic_distortions


In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import axes_dict, plot_some, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.io import load_training_data
from csbdeep.models import Config, UpsamplingCARE

## Data preparation

In [None]:
# Paired training stacks: sources from LR/, targets from HR/, each stored as (Z, Y, X).
raw_data = RawData.from_folder(
    basepath='/content/drive/MyDrive/Train_care/',
    source_dirs=['LR'],
    target_dir='HR',
    axes='ZYX',
)

In [None]:
# Degradation #1: subsample the X axis by 4 with no PSF blur.
# yield_target='target' means the distorted image is derived from the target stack.
anisotropic_transform1 = anisotropic_distortions(
    subsample_axis='X',
    subsample=4,
    psf=None,
    yield_target='target',
)

In [None]:
# Degradation #2: same as #1 but along the Y axis. It is applied below,
# Z-plane by Z-plane, to the patches already subsampled along X.
anisotropic_transform2 = anisotropic_distortions(
    subsample_axis='Y',
    subsample=4,
    psf=None,
    yield_target='target',
)

In [None]:
# Cut 380 patches of size (Z, Y, X) = (4, 64, 100) from every stack. The source
# patches (X) come out of the X-subsampling transform; Y holds the matching targets.
X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    transforms=[anisotropic_transform1],
    patch_size=(4, 64, 100),
    n_patches_per_image=380,
)

In [None]:
# Additionally subsample every X-subsampled source patch along Y, one Z-plane at a time.
# Each plane X[:, :, z] is wrapped as its own RawData (source == target) and run through
# create_patches with a patch covering the whole plane. The call uses no background filter,
# no re-normalization and no shuffling, so the output stays in the same order as X and Y.
# A single loop replaces four copy-pasted per-plane cells. The Z count and the in-plane
# patch size come from X.shape (X is indexed as [S, C, Z, Y, X] below, matching the
# original X[:,:,z,:,:]), so editing patch_size in the create_patches cell stays consistent.
X2_slices = []
for z in range(X.shape[2]):
    z_plane = X[:, :, z, :, :]
    # NOTE(review): each image here has shape (C, Y, X) with C == 1. from_arrays presumably
    # keeps only the trailing letters of 'CZYX', so the singleton channel acts as Z — confirm.
    intermediate_data = RawData.from_arrays(z_plane, z_plane, axes='CZYX')
    X2_z, _, _ = create_patches(
        raw_data            = intermediate_data,
        patch_size          = (1,) + tuple(X.shape[3:]),
        n_patches_per_image = 1,
        transforms          = [anisotropic_transform2],
        normalization       = None,
        shuffle             = False,
        patch_filter        = None,
    )
    X2_slices.append(X2_z)

# Keep the per-plane names used by the concatenation cell below (expects exactly 4 planes).
X2_0, X2_1, X2_2, X2_3 = X2_slices

In [None]:
# Every source patch must have a target patch of identical shape.
assert X.shape == Y.shape
for label, value in (("shape of X,Y =", X.shape), ("axes  of X,Y =", XY_axes)):
    print(label, value)

In [None]:
# Re-stack the Y-subsampled planes along axis 2 (the Z axis sliced out above).
X_out = np.concatenate((X2_0, X2_1, X2_2, X2_3), axis=2)
# X_out, not X, is what gets saved and trained on, so it must line up with Y patch-for-patch.
assert X_out.shape == Y.shape, (X_out.shape, Y.shape)
X_out.shape

In [None]:
# Preview two distinct groups of 5 training patches (top row: source X_out,
# bottom row: target Y). The loop index selects the group, so each figure differs.
for i in range(2):
    start = 20 + 5 * i
    batch = slice(start, start + 5)
    plt.figure(figsize=(20,4))
    plot_some(X_out[batch], Y[batch])
    plt.show()
None;

In [None]:
# Persist the (X_out, Y) pairs; the training section below reloads them from this file.
training_data_path = '/content/drive/MyDrive/Train_care/fulltrain_x4_nonoise.npz'
save_training_data(training_data_path, X_out, Y, XY_axes)

## Training

In [None]:
# Reload the saved patches. Only 5% is held out for validation, leaving as much
# data as possible for training.
# NOTE: this rebinds X and Y (previously the raw create_patches output).
(X, Y), (X_val, Y_val), axes = load_training_data(
    '/content/drive/MyDrive/Train_care/fulltrain_x4_nonoise.npz',
    validation_split=0.05,
    verbose=True,
)

# Input/output channel counts are read off the 'C' axis reported by the loader.
c = axes_dict(axes)['C']
n_channel_in = X.shape[c]
n_channel_out = Y.shape[c]

In [None]:
# Show validation patches 10-14: sources in the top row, targets in the bottom row.
# NOTE(review): plot_some appears to max-project patches with more than 2 spatial
# dimensions, so the "ZY slice" wording in the title may not match what is drawn — confirm.
val_preview = slice(10, 15)
plt.figure(figsize=(12,3))
plot_some(X_val[val_preview], Y_val[val_preview])
plt.suptitle('5 example validation patches (ZY slice, top row: source, bottom row: target)');

In [None]:
# Training configuration: 200 steps per epoch, batch size 8; everything else uses
# the Config defaults. Print it, then display all fields.
config = Config(
    axes,
    n_channel_in,
    n_channel_out,
    train_batch_size=8,
    train_steps_per_epoch=200,
)
print(config)
vars(config)

In [None]:
# New UpsamplingCARE model named 'fulltrain_x4_nonoise' in the Drive model directory.
model = UpsamplingCARE(
    config,
    'fulltrain_x4_nonoise',
    basedir='/content/drive/MyDrive/CARE_models',
)

In [None]:
# Train on (X, Y), evaluating on the held-out validation split.
validation_pair = (X_val, Y_val)
history = model.train(X, Y, validation_data=validation_pair)

In [None]:
# List the recorded metrics, then plot the losses and the MSE/MAE curves.
recorded_metrics = sorted(history.history.keys())
print(recorded_metrics)
plt.figure(figsize=(16,5))
plot_history(history, ['loss','val_loss'], ['mse','val_mse','mae','val_mae']);

In [None]:
# Predict validation patches 10-14 and show source / ground truth / prediction.
plt.figure(figsize=(20,12))
val_batch = slice(10, 15)
_P = model.keras_model.predict(X_val[val_batch])
if config.probabilistic:
    # Probabilistic models stack two outputs along the last axis; keep the first half
    # (presumably the predicted mean — confirm against the CSBDeep docs).
    n_half = _P.shape[-1] // 2
    _P = _P[..., :n_half]
plot_some(X_val[val_batch], Y_val[val_batch], _P, pmax=99.5)
plt.suptitle('5 example validation patches (ZY slice)\n'
             'top row: input (source),  '
             'middle row: target (ground truth),  '
             'bottom row: predicted from source');

## Predictions on the liver1 test volume

In [None]:

from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import save_tiff_imagej_compatible
from csbdeep.models import UpsamplingCARE

In [None]:
# Load the liver1 test pair: high-res ground truth (y) and low-res input (x).
# NOTE(review): relative paths assume the working directory contains `drive/`
# (presumably /content on Colab) — confirm.
# NOTE: this rebinds `axes`, which previously held the training-data axes.
y = imread('drive/MyDrive/liver1_test/test_hr_stacked/liver1.tif')
x = imread('drive/MyDrive/liver1_test/test_lr_stacked/liver1.tif')

axes = 'ZYX'
for label, value in (('image size =', x.shape), ('image axes =', axes)):
    print(label, value)

In [None]:
# Load the saved 'newselect_x4_nonoise' model as a plain CARE model (config=None
# reuses the stored configuration). No other cell imports CARE (only UpsamplingCARE
# and Config), so this cell previously failed with NameError; import it here.
from csbdeep.models import CARE

model = CARE(config=None, name='newselect_x4_nonoise/', basedir='drive/MyDrive/CARE_models/')

In [None]:
# `restored` was never assigned anywhere in the notebook, so this cell raised NameError.
# Compute it from the low-res test volume `x` (axes 'ZYX') using the CARE model
# loaded in the previous cell, then save it.
restored = model.predict(x, axes)
save_tiff_imagej_compatible('liver1_carerestored.tiff', restored, axes)

In [None]:
# Reload the same saved model as an UpsamplingCARE, so predict() takes an upsampling factor.
new_model = UpsamplingCARE(
    config=None,
    name='newselect_x4_nonoise/',
    basedir='drive/MyDrive/CARE_models/',
)

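## Option 1: axes `'YXZ'` then `'YZX'`

Each option upsamples the volume twice by 4×: first along array axis 2, then along array axis 1. `UpsamplingCARE.predict` upsamples whichever axis is labelled `Z`. The options differ only in which remaining axis the model treats as X and which as Y.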

In [None]:
# Input volume for the two-pass upsampling options below (rebinds the earlier `x`).
liver1_input_path = 'drive/MyDrive/liver1_test/CARE_x4/liver1.tif'
x = imread(liver1_input_path)

In [None]:
# Pass 1 labels the array axes 'YXZ' (axis 2 is upsampled);
# pass 2 labels them 'YZX' (axis 1 is upsampled). Both passes are tiled 2x2x2.
upsample_factor = 4
tiling = (2, 2, 2)

restored_1st = new_model.predict(x, 'YXZ', upsample_factor, n_tiles=tiling)
print('1st output size =', restored_1st.shape)

restored_2nd = new_model.predict(restored_1st, 'YZX', upsample_factor, n_tiles=tiling)
print('2nd output size =', restored_2nd.shape)

In [None]:
# Save the option-1 result to the same CARE_x4 folder on Drive as the input.
save_tiff_imagej_compatible(
    '/content/drive/MyDrive/liver1_test/CARE_x4/liver1_enhanced.tiff',
    restored_2nd,
    'ZYX',
)

## Option 2: axes `'XYZ'` then `'XZY'`

In [None]:
# Pass 1 labels the array axes 'XYZ' (axis 2 is upsampled 4x);
# pass 2 labels them 'XZY' (axis 1 is upsampled 4x).
# Tiled with n_tiles=(2,2,2) like option 1, so the options differ only in axis
# labelling, and the 4x-enlarged second-pass input is not predicted in one piece.
restored_1st = new_model.predict(x, 'XYZ', 4, n_tiles=(2,2,2))
print('1st output size =', restored_1st.shape)

restored_2nd = new_model.predict(restored_1st, 'XZY', 4, n_tiles=(2,2,2))
print('2nd output size =', restored_2nd.shape)

In [None]:
# NOTE(review): relative path — written to the working directory (presumably /content on
# Colab, which is not Drive and does not persist across runtimes) — confirm this is intended.
option2_path = 'liver1_option2.tiff'
save_tiff_imagej_compatible(option2_path, restored_2nd, 'ZYX')

## Option 3: axes `'YXZ'` then `'XZY'`

In [None]:
# Pass 1 labels the array axes 'YXZ' (axis 2 is upsampled 4x);
# pass 2 labels them 'XZY' (axis 1 is upsampled 4x).
# Tiled with n_tiles=(2,2,2) like option 1, so the options differ only in axis
# labelling, and the 4x-enlarged second-pass input is not predicted in one piece.
restored_1st = new_model.predict(x, 'YXZ', 4, n_tiles=(2,2,2))
print('1st output size =', restored_1st.shape)

restored_2nd = new_model.predict(restored_1st, 'XZY', 4, n_tiles=(2,2,2))
print('2nd output size =', restored_2nd.shape)

In [None]:
# NOTE(review): relative path — written to the working directory (presumably /content on
# Colab, which is not Drive and does not persist across runtimes) — confirm this is intended.
option3_path = 'liver1_option3.tiff'
save_tiff_imagej_compatible(option3_path, restored_2nd, 'ZYX')

## Option 4: axes `'XYZ'` then `'YZX'`

In [None]:
# Pass 1 labels the array axes 'XYZ' (axis 2 is upsampled 4x);
# pass 2 labels them 'YZX' (axis 1 is upsampled 4x).
# Tiled with n_tiles=(2,2,2) like option 1, so the options differ only in axis
# labelling, and the 4x-enlarged second-pass input is not predicted in one piece.
restored_1st = new_model.predict(x, 'XYZ', 4, n_tiles=(2,2,2))
print('1st output size =', restored_1st.shape)

restored_2nd = new_model.predict(restored_1st, 'YZX', 4, n_tiles=(2,2,2))
print('2nd output size =', restored_2nd.shape)

In [None]:
# NOTE(review): relative path — written to the working directory (presumably /content on
# Colab, which is not Drive and does not persist across runtimes) — confirm this is intended.
option4_path = 'liver1_option4.tiff'
save_tiff_imagej_compatible(option4_path, restored_2nd, 'ZYX')