# Network Prediction
Please run ```Convallaria-Training.ipynb``` before running this notebook.

In [None]:
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import numpy as np
import torch
from tifffile import imread
import sys
sys.path.append('..')
from unet.model import UNet
from deconoising.utils import PSNR
from deconoising import utils
from deconoising import prediction

# Pick the compute device: a GPU when the helper detects one, otherwise a fallback.
device = utils.getDevice()

### Load Data

In [None]:
# Load the Flywing test images from an .npz archive.
# NOTE(review): absolute, machine-specific path — adjust for your setup.
fpath = '/home/ubuntu/ashesh/data/Flywing/Flywing_n10/test/test_data.npz'

# np.load on an .npz returns an NpzFile that keeps the file handle open;
# the context manager closes it once 'X_test' has been read into memory.
# (The archive presumably also holds 'X_train' / 'X_val'; they are not needed here.)
with np.load(fpath) as data_dict:
    X_test = data_dict['X_test']


In [None]:
# Convolve the test images with the PSF(s) listed in psf_list to build the network input.
# NOTE(review): the meaning of PSFspecify's arguments (81, 1) is defined in
# deconoising.synthetic_data_generator — confirm (presumably kernel size and width).
from deconoising.synthetic_data_generator import PSFspecify, create_dataset

psf_list = [PSFspecify(81, 1)]
test_tensor = torch.Tensor(X_test[:, None])  # insert a singleton axis at position 1
convolved_data = create_dataset(test_tensor, psf_list).numpy()

### Load the Network
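Ensure that ```dataName``` is set to the same value as in ```Convallaria-Training.ipynb```. In this notebook the model path is hard-coded in the cell below, so make sure it points to the network saved by the training notebook.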

In [None]:
# Load the trained network object created by the 'Convallaria-Training.ipynb' notebook.
# map_location places the stored tensors on the device selected above, so a model
# saved on a GPU can still be loaded on a CPU-only machine.
# NOTE: torch>=2.6 defaults to weights_only=True, which rejects whole pickled modules;
# on those versions pass weights_only=False (only for files you trust).
net = torch.load("/home/ubuntu/ashesh/data/Flywing/Flywing_n10/train/best_N2V.net", map_location=device)

In [None]:
# Select index 0 along axis 1 (the singleton axis inserted via X_test[:, None]) to get one image per test sample.
dataTest = convolved_data[:, 0, ...]

### Evaluation

In [None]:
# Run tiled prediction on a single convolved test image.
# tiledPredict lives in the imported `prediction` module; the bare name is never imported.
index = 1
im = dataTest[index]
deconvolvedResult, denoisedResult = prediction.tiledPredict(im, net, ps=256, overlap=48, device=device)

In [None]:
# Display the dimensions of the loaded test stack.
np.shape(X_test)

In [None]:
# Quick visual check: the same 100x100 region of two raw test images, side by side.
_, ax = plt.subplots(figsize=(10, 5), ncols=2)
for axis, img_idx in zip(ax, (0, 20)):
    axis.imshow(X_test[img_idx, 100:200, 100:200])
    axis.set_title(f'X_test[{img_idx}]')
plt.show()


In [None]:
# Show one random crop of the convolved input, the denoised output, the matching
# X_test image and the deconvolved output (uses `im` / `index` from the cell above).
crop_sz = 64
crop_seed = 0  # fixed seed so the crop location is reproducible across runs
rng = np.random.default_rng(crop_seed)
# Valid top-left offsets are 0..(size - crop_sz) inclusive; integers() excludes its
# upper bound, hence the + 1 (this also avoids an error when size == crop_sz).
h = rng.integers(im.shape[0] - crop_sz + 1)
w = rng.integers(im.shape[1] - crop_sz + 1)
print(h, w, crop_sz)
crop = (slice(h, h + crop_sz), slice(w, w + crop_sz))

_, ax = plt.subplots(figsize=(20, 5), ncols=4)
ax[0].imshow(im[crop])
ax[0].set_title('Input')
if denoisedResult is not None:
    ax[1].imshow(denoisedResult[crop])
    ax[1].set_title('Denoised output')
ax[2].imshow(X_test[index][crop])
ax[2].set_title('X_test (reference)')
ax[3].imshow(deconvolvedResult[crop])
ax[3].set_title('Deconv. output')
plt.show()

In [None]:
# Now we process every test image and compute PSNR values against its reference.


def _mean_2sem(values):
    """Return (mean, 2 * standard error of the mean) of a non-empty sequence of numbers."""
    arr = np.asarray(values, dtype=float)
    return arr.mean(), 2 * arr.std() / np.sqrt(float(len(arr)))


def _plot_panels(panels, vmin, vmax):
    """Draw (title, image) pairs side by side in one 15x15 figure with a shared colour range."""
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 15))
    for axis, (title, img) in zip(np.atleast_1d(axes), panels):
        axis.set_title(title)
        axis.imshow(img, vmax=vmax, vmin=vmin, cmap='magma')
    return fig


psnr_result = []  # denoised output vs. reference
psnr_deconv = []  # deconvolved output vs. reference
psnr_input = []   # convolved input vs. reference

# We iterate over all test images.
for index in range(dataTest.shape[0]):
    im = dataTest[index]

    # Tiling keeps memory bounded; if you get an error try a smaller patch size (ps).
    # Here we are predicting the deconvolved and denoised image.
    deconvolvedResult, denoisedResult = prediction.tiledPredict(im, net, ps=256, overlap=48, device=device)

    # dataTest[index] was produced by convolving X_test[index], so X_test[index]
    # is the matching reference. It is the pre-convolution image, which is why the
    # deconvolved output is reported in addition to the denoised one.
    gt = X_test[index]
    rangePSNR = np.max(gt) - np.min(gt)

    psnr_in = PSNR(gt, im, rangePSNR)
    psnr_dc = PSNR(gt, deconvolvedResult, rangePSNR)
    psnr_input.append(psnr_in)
    psnr_deconv.append(psnr_dc)
    print("image:", index)
    print("PSNR input", psnr_in)
    print("PSNR deconvolved", psnr_dc)
    # tiledPredict may return no denoised image (the crop-display cell guards the same case).
    if denoisedResult is not None:
        psnr_dn = PSNR(gt, denoisedResult, rangePSNR)
        psnr_result.append(psnr_dn)
        print("PSNR denoised", psnr_dn)
    print('-----------------------------------')

# We display the results for the last test image.
vmi = np.percentile(gt, 0.01)
vma = np.percentile(gt, 99)

panels = [('Input image', im), ('Deconv. output', deconvolvedResult)]
if denoisedResult is not None:
    panels.append(('Denoised output', denoisedResult))
_plot_panels(panels, vmi, vma)
_plot_panels([(title, img[100:200, 150:250]) for title, img in panels], vmi, vma)
plt.show()

for label, values in (("input", psnr_input), ("deconvolved", psnr_deconv), ("denoised", psnr_result)):
    if values:
        mean_psnr, two_sem = _mean_2sem(values)
        print(f"Avg PSNR {label}:", mean_psnr, '+-(2SEM)', two_sem)