# Predict cryo-CARE

In this notebook we take the two reconstructed half tomograms (even/odd) and apply the trained network to each of them. Afterwards we average the two denoised halves voxel-wise to obtain our final restored tomogram.

In [None]:
from train_cryo_care import CryoCARE
from generate_train_data import normalize, compute_mean_std, denormalize

import mrcfile
import numpy as np
import os
import subprocess

from matplotlib import pyplot as plt

from glob import glob

In [None]:
# GPU selection: expose only one CUDA device to this kernel.
# Change the device ID below (0-7) to choose the GPU the network runs on.
# NOTE(review): CUDA reads this variable when the GPU is first initialized, so it
# presumably must be set before any GPU work happens in this kernel.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})  # <<< GPU ID goes here

## Load Tomograms

In [None]:
# Paths (or glob patterns) of the two half tomograms. 'half-tomo.rec' is the default
# IMOD name; if you used SIRT, change these accordingly
# (e.g. 'frames/even/tomogram/half-tomo_SIRT_iter_03.rec').
EVEN_TOMO_PATTERN = 'frames/even/tomogram/half-tomo.rec'
ODD_TOMO_PATTERN = 'frames/odd/tomogram/half-tomo.rec'


def load_tomogram(pattern):
    """Read the voxel data of the first MRC file matching ``pattern``.

    Parameters
    ----------
    pattern : str
        Path or glob pattern of a reconstructed half tomogram.

    Returns
    -------
    numpy.ndarray
        The volume's data array as returned by ``mrcfile``.

    Raises
    ------
    FileNotFoundError
        If no file matches ``pattern``.
    """
    # sorted() makes the choice deterministic if the pattern matches several files.
    matches = sorted(glob(pattern))
    if not matches:
        raise FileNotFoundError('No tomogram found matching {!r}'.format(pattern))
    # Context manager closes the file handle once the data has been read into memory.
    with mrcfile.open(matches[0]) as mrc:
        return mrc.data


even = load_tomogram(EVEN_TOMO_PATTERN)
odd = load_tomogram(ODD_TOMO_PATTERN)

In [None]:
# Normalisation target: zero mean and unit standard deviation. The statistics are
# computed once over both halves stacked together, so even and odd are later
# transformed with exactly the same mean and std.
mean, std = compute_mean_std(np.stack([even, odd]))
print(mean, std)  # diagnostic output only

In [None]:
# Apply the shared statistics to each half (even first, then odd).
even_n, odd_n = (normalize(half, mean, std) for half in (even, odd))

## Load Network
We load the network that was trained in notebook [04].

In [None]:
# Load the trained network saved under the name 'denoiser_model' (base directory '').
# NOTE(review): passing None as the first argument presumably means "no new config,
# load the saved model", following CSBDeep's CARE(config, name, basedir) — confirm.
model = CryoCARE(
    None,
    'denoiser_model',
    basedir='',
)

## Apply Network
If `model.predict` fails with an error whose traceback mentions running out of memory, increase the number of tiles: change `n_tiles=(4,8,8)` to larger values (e.g. `n_tiles=(8,8,8)`) in both prediction cells below.

In [None]:
# Denoise the even half and map the result back to the original intensity range
# with denormalize(..., mean, std).
# normalizer=None: the input is already normalized above, so the model must not
# normalize it again. n_tiles splits the ZYX volume into 4x8x8 tiles; raise these
# numbers if prediction runs out of memory.
even_denoised = denormalize(
    model.predict(even_n, axes='ZYX', n_tiles=(4, 8, 8), normalizer=None),
    mean,
    std,
)

In [None]:
# Denoise the odd half exactly like the even half: normalized input, model-side
# normalizer disabled, output denormalized with the shared mean/std.
odd_denoised = denormalize(
    model.predict(odd_n, axes='ZYX', n_tiles=(4, 8, 8), normalizer=None),
    mean,
    std,
)

In [None]:
# Final restored tomogram: voxel-wise mean of the two denoised halves.
tomo_denoised = 0.5 * (even_denoised + odd_denoised)
# The same average of the raw halves, kept only for visual comparison.
tomo_raw_average = 0.5 * (even + odd)

In [None]:
# Create the output directory. exist_ok=True makes the cell safe to re-run without
# the check-then-create race of isdir() + mkdir(), and missing parents are created.
os.makedirs('output/', exist_ok=True)

In [None]:
# Save the denoised tomogram.
# overwrite=True: mode 'w+' refuses to replace an existing file by default, which
# made this cell fail whenever the notebook was re-run.
# The data is written as float32 (MRC has no 64-bit float mode); np.asarray avoids
# a copy when the array is already float32.
with mrcfile.open('output/tomo_denoised.mrc', mode='w+', overwrite=True) as mrc:
    mrc.set_data(np.asarray(tomo_denoised, dtype=np.float32))

## (optional) Quick inspection of results

In [None]:
# Show the dimensions of the denoised tomogram (axis order presumably Z, Y, X, as
# passed to model.predict) so valid coordinates can be chosen for the plot below.
denoised_shape = np.shape(even_denoised)
print(denoised_shape)

In [None]:
# Choose the 2D cross-section to display. Exactly one index must be a single
# integer so that the result is 2D; slice(a, b) is equivalent to a:b.
# With the ZYX axis order used for prediction, the default below fixes X=8 and
# shows a Z-Y plane; adjust it using the shape printed above.
# It must be a tuple: indexing with a list of slices is an error in NumPy >= 1.23.
area_coordinates = (slice(10, 200), slice(10, 200), 8)

# (title, volume) for each panel, filled row by row into a 3x2 grid.
# All "Raw" panels use the unnormalized data so they are directly comparable.
panels = [
    ('Even Raw', even),
    ('Even Denoised', even_denoised),
    ('Odd Raw', odd),
    ('Odd Denoised', odd_denoised),
    ('Voxel-wise Average Raw', tomo_raw_average),
    ('Voxel-wise Average Denoised', tomo_denoised),
]

fig, axes = plt.subplots(3, 2, figsize=(20, 25))
for ax, (title, volume) in zip(axes.flat, panels):
    ax.imshow(volume[area_coordinates], cmap='gray')
    ax.set_title(title)

In [None]:
# Shut down the kernel so the GPU(s) used by the network are freed. This is a
# blunt approach: Jupyter will show a dialog saying
# 'The kernel appears to have died. It will restart automatically.'
# That is expected and can be accepted; the denoised tomogram has already been
# written to output/tomo_denoised.mrc by the save cell above.
# NOTE(review): every in-memory variable of this notebook is lost after this cell.

exit()