In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
%matplotlib inline
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import numpy as np
import torch, torch.optim
import torch.nn.functional as F
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark =True
dtype = torch.cuda.FloatTensor
import os, sys
sys.path.append('utils/*')

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import models as md
import utils.common_utils as cu
import utils.diffuser_utils as df 
import utils.utils_hyperspectral as helper


# Single-shot Imaging Demo

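Load the PSF, the 2D measurement, and the rolling-shutter mask (plus the ground-truth data `gt_np`, which is used to synthesize the measurement in simulated mode).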

In [None]:
# Data-source switch for the rest of the notebook:
#   True  -> synthesize the measurement from the ground truth (simulated data)
#   False -> reconstruct from the recorded experimental measurement
simulated = True

In [None]:
downsampling_factor = 2  # NOTE(review): not referenced anywhere else in this notebook
meas_np, mask_np, psf_np, gt_np = helper.load_data(simulated = simulated)

# Frame of the rolling-shutter mask to preview. Clamped so the preview does not
# raise IndexError when the mask has fewer than 21 entries along axis 2.
mask_preview_idx = min(20, mask_np.shape[2] - 1)

plt.figure(figsize=(20, 10))
plt.subplot(1, 3, 1); plt.title('PSF'); plt.imshow(psf_np)
plt.subplot(1, 3, 2); plt.title('Measurement'); plt.imshow(meas_np)
plt.subplot(1, 3, 3); plt.title('Rolling shutter mask'); plt.imshow(mask_np[:, :, mask_preview_idx])
plt.show()

Initialize the lensless forward model

In [None]:
DIMS0 = meas_np.shape[0]  # Image Dimensions (axis 0 of the measurement)
DIMS1 = meas_np.shape[1]  # Image Dimensions (axis 1 of the measurement)

py = DIMS0 // 2  # Pad size on each side of axis 0 (half the measurement height)
px = DIMS1 // 2  # Pad size on each side of axis 1 (half the measurement width)


def pad(x):
    """Zero-pad the first two axes of ``x`` by ``py`` and ``px`` on each side.

    Any trailing axes (channels, frames, ...) are left unpadded, so this works
    for arrays of any rank >= 2.

    NOTE(review): a later cell rebinds the name ``pad`` to the string
    'reflection', so this helper is only reachable before that cell runs.

    Parameters
    ----------
    x : array_like
        Array with at least two dimensions.

    Returns
    -------
    np.ndarray
        Zero-padded copy with shape ``(x.shape[0] + 2*py, x.shape[1] + 2*px, *x.shape[2:])``.

    Raises
    ------
    ValueError
        If ``x`` has fewer than two dimensions.
    """
    ndim = np.ndim(x)
    if ndim < 2:
        raise ValueError('pad expects an array with at least 2 dimensions, got %d' % ndim)
    pad_width = [(py, py), (px, px)] + [(0, 0)] * (ndim - 2)
    return np.pad(x, pad_width, mode='constant')


# Only the PSF is padded; the measurement keeps its original size.
# NOTE(review): presumably the forward model crops back to the measurement
# size internally — confirm in diffuser_utils.Forward_Model_combined.
psf_pad = pad(psf_np)

# Frequency-domain PSF: ifftshift (over all axes), then a 2D FFT over the last two axes.
h_full = np.fft.fft2(np.fft.ifftshift(psf_pad))

In [None]:
# Combined forward operator built from the padded PSF spectrum and the
# rolling-shutter mask ('spectral' is the imaging_type passed to the model).
forward = df.Forward_Model_combined(
    h_full,
    shutter=mask_np,
    imaging_type='spectral',
)

In [None]:
if simulated:
    # Synthesize the measurement by passing the ground truth through the forward
    # model: transpose(2, 0, 1) moves the last axis to the front, and unsqueeze(0)
    # adds a leading (presumably batch) axis before the call.
    # NOTE(review): main() L2-normalises the *generated* measurement before the
    # MSE, but this simulated measurement is not normalised here — confirm the
    # two are on the same scale.
    meas_torch = forward(cu.np_to_torch(gt_np.transpose(2, 0, 1)).type(dtype).unsqueeze(0))
    meas_np = cu.torch_to_np(meas_torch)[0]
    plt.title('Simulated measurement')
    plt.imshow(meas_np)
    plt.show()

Set up parameters and network

In [None]:
# Define network hyperparameters:
input_depth = 32         # depth of the random network input (set to 1 below for experimental data)
INPUT = 'noise'          # input type passed to cu.get_noise
pad = 'reflection'       # padding mode passed to md.get_net
                         # NOTE(review): this rebinds `pad`, shadowing the pad() helper defined earlier
LR = 1e-3                # Adam learning rate
tv_weight = 0            # weight of the TV term in the loss (0 = no TV regularisation)
reg_noise_std = 0.05     # std of the noise added to the network input at every iteration
SEED = 0                 # seed for the random network input and the weight initialisation

# Seed before anything random is drawn so Restart & Run All reproduces the result.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Build the network input and a matching untrained network for the chosen data source.
if simulated:
    # 2D skip network; the input's spatial size matches the measurement.
    num_iter = 100000
    net_input = cu.get_noise(input_depth, INPUT, (meas_np.shape[0], meas_np.shape[1])).type(dtype).detach()
    NET_TYPE = 'skip'
    net = md.get_net(input_depth, NET_TYPE, pad, n_channels=32, skip_n33d=128, skip_n33u=128,
                     skip_n11=4, num_scales=5, upsample_mode='bilinear').type(dtype)
else:
    print('experimental')
    # 3D skip network with a single input channel; the input gets an extra leading
    # axis of length mask_np.shape[-1] in front of the measurement's spatial size.
    num_iter = 4600
    input_depth = 1
    net_input = cu.get_noise(input_depth, INPUT, (mask_np.shape[-1], meas_np.shape[0], meas_np.shape[1])).type(dtype).detach()
    NET_TYPE = 'skip3D'
    net = md.get_net(input_depth, NET_TYPE, pad, n_channels=1, skip_n33d=128, skip_n33u=128,
                     skip_n11=4, num_scales=4, upsample_mode='trilinear').type(dtype)

# Fixed copy of the input, plus a same-shaped buffer that main() refills with noise.
net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()

# Reinitialize the optimizer over all network parameters.
p = list(net.parameters())
optimizer = torch.optim.Adam(p, lr=LR)

# Losses
mse = torch.nn.MSELoss().type(dtype)

def main(plot_every=100):
    """Optimise the untrained network so its rendered measurement matches the data.

    Each of the ``num_iter`` iterations (notebook global):

    1. perturbs the fixed input ``net_input_saved`` with fresh Gaussian noise
       scaled by ``reg_noise_std`` (``noise`` is reused as the sample buffer);
    2. runs ``net`` to get the reconstruction ``recons``;
    3. renders it with ``forward.forward`` and L2-normalises the result over
       dims 1 and 2;
    4. minimises the MSE to the measurement, plus ``tv_weight`` times
       ``df.tv_loss(recons)`` when ``tv_weight`` is nonzero.

    Parameters
    ----------
    plot_every : int, optional
        Plot the current reconstruction and log the loss every ``plot_every``
        iterations (default 100, the original cadence). Use 0 or a negative
        value to disable the periodic plots.

    Returns
    -------
    The output of ``helper.preplot`` applied to the final reconstruction.

    Raises
    ------
    ValueError
        If ``num_iter`` is smaller than 1 (no reconstruction would be produced).

    Notes
    -----
    Sets the global ``recons`` to the last network output so later cells can
    reuse it.
    """
    global recons
    if num_iter < 1:
        raise ValueError('num_iter must be at least 1, got %r' % (num_iter,))

    # Target measurement as a tensor of the notebook-wide type; dtype already
    # places it on the GPU, so no separate .cuda() call is needed.
    meas_ts = cu.np_to_ts(meas_np).detach().clone().type(dtype)

    for i in range(num_iter):
        optimizer.zero_grad()
        # Input regularisation: a new noise sample every step, scaled by reg_noise_std.
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)
        recons = net(net_input)
        gen_meas = forward.forward(recons)
        gen_meas = F.normalize(gen_meas, dim=[1, 2], p=2)
        loss = mse(gen_meas, meas_ts)
        if tv_weight:
            # Only evaluate TV when enabled: if tv_loss were inf/NaN,
            # 0 * tv_loss would still turn the whole loss into NaN.
            loss += tv_weight * df.tv_loss(recons)
        loss.backward()
        print('Iteration %05d, loss %.8f '%(i, loss.item()), '\r',  end='')
        if plot_every > 0 and i % plot_every == 0:
            helper.plot(recons)
            print('Iteration {}, loss {:.8f}'.format(i, loss.item()))
        optimizer.step()
    return helper.preplot(recons)

### Run the reconstruction

In [None]:
# Run the full optimisation. This can take a long time: num_iter is 100000 for
# simulated data and 4600 for experimental data (set in the network cell above).
full_recons = main()

In [None]:
# NOTE(review): this overwrites the `full_recons` returned by main() (built with
# helper.preplot) with helper.preplot2 applied to the global `recons` that main()
# leaves behind, so it only works after main() has run in this kernel. The
# difference between preplot and preplot2 is not visible here — confirm which
# result the slider below is meant to show.
full_recons = helper.preplot2(recons)

Reconstructed video

In [None]:
def plot_slider(x):
    """Slider callback: draw frame ``x`` of ``full_recons`` (last axis) and return ``x``."""
    plt.title('Reconstruction: frame %d' % x)
    plt.axis('off')
    plt.imshow(full_recons[..., x])
    return x


# One slider step per reconstructed frame, from 0 to the last index on the final axis.
interactive(
    plot_slider,
    x=(0, full_recons.shape[-1] - 1, 1),
)