In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
import matplotlib.pyplot as plt
%matplotlib inline
from ipywidgets import interact, interactive, fixed, interact_manual
from IPython.display import clear_output

import os, sys, glob, cv2, hdf5storage, time
import torch.nn as nn

from torchvision import transforms
import scipy.io

import models.dataset as ds
import helper as hp

import matplotlib as mpl
mpl.rc('image', cmap='inferno')


# Enumerate GPUs in PCI bus order and expose only physical GPU 1 to this process;
# CUDA then addresses that single visible GPU as 'cuda:0'.
# NOTE(review): these variables only take effect if CUDA has not been initialized yet.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})
device = "cuda:0"
dtype = torch.cuda.FloatTensor

In [None]:
# Show current GPU utilization/memory (shell command; needs the `gpustat` CLI on PATH).
!gpustat

# MultiWienerNet 3D Deconvolution Demo

In this notebook, we load a pretrained MultiWienerNet and demonstrate fast spatially-varying 3D deconvolution on both simulated and real data. We compare its performance against a pretrained U-Net, a pretrained WienerNet (spatially invariant), and spatially-varying FISTA.

## Load in saved models

In [None]:
# Filepaths to saved models
# Directories holding the saved checkpoint of each pretrained network.
multiwiener_file_path, unet_file_path, wiener_file_path = (
    'saved_models/trained_multiwiener3D/',
    'saved_models/trained_unet3D/',
    'saved_models/trained_wiener3D/',
)

In [None]:
# Load each pretrained network from its checkpoint directory onto `device`
# (U-Net baseline, single-PSF WienerNet, and MultiWienerNet).
unet_model = hp.load_pretrained_model(
    unet_file_path, model_type='unet', device=device)
wiener_model = hp.load_pretrained_model(
    wiener_file_path, model_type='wiener', device=device)
multiwiener_model = hp.load_pretrained_model(
    multiwiener_file_path, model_type='multiwiener', device=device)

## Load in data 

In [None]:
# Transforms applied to every simulated sample, in this order:
# ds.downsize(ds=.75) -> ds.AddNoise() -> ds.ToTensor().
down_size = ds.downsize(ds=.75)
to_tensor = ds.ToTensor()
add_noise = ds.AddNoise()
pipeline = transforms.Compose([down_size, add_noise, to_tensor])

filepath_gt = '../data/3D_data_simulated/'
# NOTE(review): glob.glob returns paths in filesystem order, not sorted order, so which
# file is sample 0 vs 1 (and its pairing with `saved_fista` in the FISTA cell) depends on
# the filesystem. Sorting would make it deterministic but may swap that pairing.
filepath_all = glob.glob(filepath_gt + '*')
filepath_test = filepath_all  # every simulated sample is used for testing

dataset_test = ds.MiniscopeDataset(filepath_test, transform=pipeline)

## Run deconvolution for simulated data

### Load in measurement

In [None]:
img_ind = 1   # two simulated samples are provided: indices 0 and 1
# NOTE(review): if ds.AddNoise is random, no seed is set, so the measurement (and the
# PSNRs reported below) change from run to run.
sample_batched = dataset_test[img_ind]
raw_meas = sample_batched['meas']
meas_np = hp.to_np(raw_meas)
# Prepend a singleton axis (used as the batch axis by the models below — assumed).
sample_batched['meas'] = raw_meas.unsqueeze(0)

plt.imshow(meas_np);
plt.title('measurement');
print('measurement shape:', meas_np.shape)

### Deconvolve! 

In [None]:
def _timed_forward(model, inp):
    """Run ``model(inp)`` and return ``(output, elapsed_seconds)``.

    CUDA kernels are launched asynchronously, so the device is synchronized
    before starting and before stopping the clock. Otherwise the elapsed time
    would mostly reflect kernel-launch overhead instead of the forward pass.
    Only the forward pass is timed; preparing ``inp`` happens beforehand.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    out = model(inp)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return out, time.perf_counter() - start


# NOTE: the first network to run typically also pays one-off CUDA/cuDNN start-up
# costs; re-run this cell for steadier timings.
t_list = []
with torch.no_grad():
    # U-Net input: the measurement tiled 32x along axis 2 (presumably the depth axis).
    unet_in = sample_batched['meas'].repeat(1, 1, 32, 1, 1).to(device)
    out_unet, elapsed = _timed_forward(unet_model, unet_in)
    t_list.append(elapsed)

    # Wiener-based models take the single-plane measurement directly.
    wiener_in = sample_batched['meas'].to(device)
    out_wiener, elapsed = _timed_forward(wiener_model, wiener_in)
    t_list.append(elapsed)

    out_multiwiener, elapsed = _timed_forward(multiwiener_model, wiener_in)
    t_list.append(elapsed)

recon_titles = ['Unet', 'WienerNet', 'MultiWienerNet (Ours)']
recon_list = [out_unet, out_wiener, out_multiwiener]

### Plot results

In [None]:
gt_np = hp.to_np(sample_batched['im_gt'].unsqueeze(0))
recons_np = [hp.to_np(recon) for recon in recon_list]

# hp.max_proj of the ground truth in the first panel, then one panel per network.
f, ax = plt.subplots(1, 4, figsize=(15,15))
ax[0].imshow(hp.max_proj(gt_np))
ax[0].set_title('Ground Truth')
for panel, (title, recon) in enumerate(zip(recon_titles, recons_np), start=1):
    ax[panel].imshow(hp.max_proj(recon))
    ax[panel].set_title(title)

# Runtime and PSNR against ground truth for each network.
for title, elapsed, recon in zip(recon_titles, t_list, recons_np):
    print(title, ': ', np.round(elapsed,2),'s,  PSNR: ', np.round(hp.calc_psnr(gt_np, recon),2))

In [None]:
out_np = recons_np[-1]


def plot_slider(x):
    """Show slice ``x`` (first axis) of the ground truth and of every reconstruction.

    Each panel is scaled from 0 to the maximum of its own volume, so the
    brightness of one volume is comparable across slices.

    Args:
        x: index along the first axis of ``gt_np`` and of each ``recons_np`` volume.

    Returns:
        ``x`` (echoed back to the ipywidgets ``interactive`` output).
    """
    f, ax = plt.subplots(1, 4, figsize=(15,15))

    ax[0].imshow(gt_np[x], vmin=0, vmax=np.max(gt_np))
    ax[0].set_title('Ground Truth, frame %d' % x)
    ax[0].axis('off')
    for i, (title, recon) in enumerate(zip(recon_titles, recons_np), start=1):
        ax[i].imshow(recon[x], vmin=0, vmax=np.max(recon))
        ax[i].set_title(title)
        ax[i].axis('off')

    return x


# Bound the slider by the shallowest volume so every panel has slice x.
interactive(plot_slider, x=(0, min(v.shape[0] for v in [gt_np, *recons_np]) - 1, 1))

### Compare against spatially-varying FISTA

In [None]:
#compare to fista
saved_fista = [ 'fista3D-fourCells.mat', 'fista3D-cellcool.mat',]

Ifista=scipy.io.loadmat('../data/' + saved_fista[img_ind])
Ifista=Ifista['xhat_out']
Ifista=Ifista.transpose([2,0,1])/np.max(Ifista)

f, ax = plt.subplots(1, 2, figsize=(10,5))
ax[0].imshow(hp.max_proj(Ifista))
ax[0].set_title('FISTA result')
ax[1].imshow(hp.max_proj(recons_np[-1]))
ax[1].set_title(recon_titles[-1])

print('FISTA PSNR: ', np.round(hp.calc_psnr(gt_np, Ifista),2))

## Run deconvolution for real data

In [None]:
img_ind = 0 # 0: resolution target, 1: waterbear

# glob.glob returns files in filesystem-dependent order; sort so that `img_ind`
# selects the same file on every machine.
# NOTE(review): confirm the 0/1 mapping above matches the sorted file names.
loaded_meas = sorted(glob.glob('../data/real_data/*'))
meas_loaded = scipy.io.loadmat(loaded_meas[img_ind])['b']

In [None]:
# Shape of the raw measurement before the crop/resize below.
np.shape(meas_loaded)

In [None]:
# Crop the raw frame to rows 18:466 and cols 4:644, then resize by 0.75 in both
# directions (presumably to match ds.downsize(ds=.75) used for simulated data).
cropped = meas_loaded[18:466, 4:644]
meas = cv2.resize(cropped, (0, 0), fx=0.75, fy=0.75)
# Add three leading singleton axes, giving the 5-D layout the networks are called with.
meas_tensor = torch.tensor(meas, dtype=torch.float32, device=device)[None, None, None]
plt.imshow(meas)

In [None]:
# U-Net input: the measurement tiled 32x along axis 2, as in the simulated-data cell.
# `meas_tensor` already lives on `device`, so no extra `.to(device)` is needed.
meas_t = meas_tensor.repeat(1,1,32,1,1)
with torch.no_grad():
    out_unet = unet_model(meas_t)
    # The Wiener-based models take the single-plane measurement, matching how they are
    # called on simulated data and in the waterbear movie loop.
    out_wiener = wiener_model(meas_tensor)
    out_multiwiener = multiwiener_model(meas_tensor)

recon_titles = ['Unet', 'WienerNet', 'MultiWienerNet (Ours)']
recon_list = [out_unet, out_wiener, out_multiwiener]

In [None]:

# Debug view: output of each model's `.wiener_model` sub-module (presumably the
# Wiener-deconvolution stage) on the tiled real measurement. Distinct names keep the
# final reconstructions `out_wiener` / `out_multiwiener` from being overwritten.
with torch.no_grad():
    wiener_stage_wiener = wiener_model.wiener_model(meas_t)
    wiener_stage_multiwiener = multiwiener_model.wiener_model(meas_t)

# NOTE(review): index [0, 4, 0] presumably picks batch 0, Wiener output/channel 4 and
# plane 0 — confirm against the model's output layout.
plt.imshow(wiener_stage_multiwiener[0, 4, 0].cpu().numpy())
plt.title('MultiWienerNet wiener_model output [0, 4, 0]')
plt.colorbar();


In [None]:
recons_np = list(map(hp.to_np, recon_list))

f, ax = plt.subplots(1, 3, figsize=(15,15))
for i, recon in enumerate(recons_np):
    # Resolution target (img_ind 0): show the single slice at index 1;
    # otherwise show hp.max_proj of the whole volume.
    image = recon[1] if img_ind == 0 else hp.max_proj(recon)
    ax[i].imshow(image)
    ax[i].set_title(recon_titles[i])

In [None]:
def plot_slider(x):
    """Show slice ``x`` (first axis) of each real-data reconstruction side by side.

    Each panel is scaled from 0 to its own volume's maximum. The first panel's
    title also carries the slice number.

    Args:
        x: index along the first axis of each ``recons_np`` volume.

    Returns:
        ``x`` (echoed back to the ipywidgets ``interactive`` output).
    """
    f, ax = plt.subplots(1, 3, figsize=(15,15))

    for i, (title, recon) in enumerate(zip(recon_titles, recons_np)):
        ax[i].imshow(recon[x], vmin=0, vmax=np.max(recon))
        ax[i].axis('off')
        ax[i].set_title('%s, frame %d' % (title, x) if i == 0 else title)

    return x


# Bound the slider by these real-data reconstructions, not by a variable left over
# from the simulated-data section.
interactive(plot_slider, x=(0, min(r.shape[0] for r in recons_np) - 1, 1))

## Run deconvolution on a real-data movie (waterbear)

In [None]:
# Waterbear movie measurements. The default is a machine-specific absolute path;
# set the WATERBEAR_PATH environment variable to point elsewhere instead of editing it.
WATERBEAR_PATH = os.environ.get(
    'WATERBEAR_PATH',
    '/media/lahvahndata/Kyrollos/LearnedMiniscope3D/real_data/waterbear_all.mat',
)
# hdf5storage (rather than scipy.io) is used here, presumably because this is an
# HDF5-based (v7.3) .mat file.
waterbear = hdf5storage.loadmat(WATERBEAR_PATH)['b']

In [None]:
# Shape of the raw movie; the frame index is the last axis.
np.shape(waterbear)

In [None]:
# Same crop (rows 18:466, cols 4:644) and 0.75 resize as the single real measurement,
# applied to every frame stacked along the last axis.
cropped_movie = waterbear[18:466, 4:644, :]
meas = cv2.resize(cropped_movie, (0, 0), fx=0.75, fy=0.75)
# Prepend three singleton axes; frames stay on the last axis.
meas_tensor = torch.tensor(meas, dtype=torch.float32, device=device)[None, None, None]
plt.imshow(meas[..., 0])

In [None]:
def plot_slider(x):
    """Display raw waterbear measurement frame ``x`` (last-axis index of ``meas``).

    Returns:
        ``x`` (echoed back to the ipywidgets ``interactive`` output).
    """
    # This panel shows the measurement, not a reconstruction.
    plt.title('Measurement: frame %d'%(x))
    plt.axis('off')
    plt.imshow(meas[...,x])
    return x


interactive(plot_slider,x=(0,meas.shape[-1]-1,1))

In [None]:

# Reconstruct up to the first 30 movie frames with MultiWienerNet and keep two
# max projections of each reconstructed volume for the viewer below.
N_FRAMES = min(30, meas_tensor.shape[-1])  # never index past the last frame

out_bear_xy = []
out_bear_yz = []
with torch.no_grad():
    for t in range(N_FRAMES):
        print('processing image: ', t, end='\r')
        # Single-plane input (no 32x tiling), as for the simulated data.
        out_waterbear = multiwiener_model(meas_tensor[..., t])
        # [0, 0] drops the two leading axes — assumed (batch, channel); TODO confirm.
        out_waterbear_np = out_waterbear.cpu().numpy()[0, 0]

        out_bear_xy.append(np.max(out_waterbear_np, 0))  # max over axis 0 (depth, assumed)
        out_bear_yz.append(np.max(out_waterbear_np, 2))  # max over axis 2


In [None]:
# Stack the per-frame projections into arrays whose first axis is the frame index.
out_bear_xy, out_bear_yz = np.array(out_bear_xy), np.array(out_bear_yz)

In [None]:
def plot_slider(x):
    """Show movie frame ``x``: the measurement, the xy projection and the transposed
    yz projection of its reconstruction, each on a 0..global-maximum color scale.

    Returns:
        ``x`` (echoed back to the ipywidgets ``interactive`` output).
    """
    panels = (
        (meas[...,x], np.max(meas), 'Measurement'),
        (out_bear_xy[x], np.max(out_bear_xy), 'Reconstruction: frame %d'%(x)),
        (out_bear_yz[x].transpose(), np.max(out_bear_yz), None),
    )
    f, ax = plt.subplots(1, 3, figsize=(15,3))
    for axis, (image, peak, title) in zip(ax, panels):
        axis.imshow(image, vmin=0, vmax=peak)
        if title is not None:
            axis.set_title(title)
        axis.axis('off')
    return x


interactive(plot_slider,x=(0,out_bear_xy.shape[0]-1,1))

## Visualize Learned PSFs

In [None]:
# WienerNet's learned PSF tensor, detached and copied to host memory as a NumPy array.
learned_psfs_wiener_np = (
    wiener_model.wiener_model.psfs
    .detach()
    .cpu()
    .numpy()
)

In [None]:
def plot_slider(x):
    """Display WienerNet learned PSF ``x`` (first-axis index of the PSF array).

    Returns:
        ``x`` (echoed back to the ipywidgets ``interactive`` output).
    """
    # These are learned PSFs, not reconstructions.
    plt.title('WienerNet learned PSF: index %d'%(x))
    plt.axis('off')
    plt.imshow(learned_psfs_wiener_np[x])
    return x


interactive(plot_slider,x=(0,learned_psfs_wiener_np.shape[0]-1,1))

In [None]:
# MultiWienerNet's learned `psfs` and `Ks` (presumably the Wiener regularization
# parameters), detached and copied to host memory as NumPy arrays.
_wiener_layer = multiwiener_model.wiener_model
learned_psfs_np = _wiener_layer.psfs.detach().cpu().numpy()
learned_Ks_np = _wiener_layer.Ks.detach().cpu().numpy()

In [None]:
# Which learned PSF (first-axis index of learned_psfs_np) to browse; presumably one of
# the spatially-varying PSFs.
PSF_IDX = 4


def plot_slider(x):
    """Display slice ``x`` (second-axis index) of MultiWienerNet learned PSF ``PSF_IDX``.

    Returns:
        ``x`` (echoed back to the ipywidgets ``interactive`` output).
    """
    # These are learned PSFs, not reconstructions.
    plt.title('MultiWienerNet learned PSF %d: index %d' % (PSF_IDX, x))
    plt.axis('off')
    plt.imshow(learned_psfs_np[PSF_IDX][x])
    return x


interactive(plot_slider,x=(0,learned_psfs_np.shape[1]-1,1))

In [None]:
# Absolute difference between learned PSFs 8 and 0 at second-axis index 20, to show
# how much the learned PSFs differ from each other.
x = 20
psf_abs_diff = np.abs(learned_psfs_np[8][x] - learned_psfs_np[0][x])
plt.imshow(psf_abs_diff); plt.colorbar()