In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
import torch
import torch.optim
torch.backends.cudnn.enabled = True
dtype = torch.cuda.FloatTensor
import torch.nn.functional as F

import glob, os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import utils.diffuser_utils as df 
import utils.common_utils as cu
import models as md
import utils.utils_2d as helper  # helper functions 

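# 2D Imaging Demo

Reconstruct a 2D scene from a (partially erased) lensless diffuser measurement. A network is optimized, starting from a noise input, so that its output, passed through the PSF-based forward model, matches the measurement.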

Choose whether to reconstruct from a simulated or an experimental measurement, and set the erasure rate (the fraction of the measurement that is erased).

In [None]:
simulation = False  # False for experimental data, True for simulated data
erasure_rate = 0.5  # fraction of the measurement to erase, in [0, 1]; 0 keeps the full measurement

# Fail fast on invalid settings instead of deep inside the data-loading cell.
assert isinstance(simulation, bool), "simulation must be True or False"
assert 0.0 <= erasure_rate <= 1.0, "erasure_rate must lie in [0, 1]"

### Load data

Load the PSF, create a random erasure mask, and initialize the forward model.

In [None]:
downsample_f = 4  # downsampling factor forwarded to helper.load_psf
psf_np = helper.load_psf('./data/2d_imaging/psf.tiff', downsample_f)
erasure_np = helper.get_eraser(psf_np, erasure_rate)
# The forward model is built from the PSF summed over axis 2, plus the erasure mask.
forward = df.Forward_Model(
    np.sum(psf_np, axis=2),
    erasure_np,
)

Load the ground truth and the lensless measurement, then apply the erasure to the measurement.

In [None]:
# NOTE(review): the scene mapping below was recorded against the original,
# unsorted glob() listing -- re-check it now that the file list is sorted.
img_index = 7   # 2:crab 8:bottled caps
file_path_diffuser = 'data/2d_imaging/diffuser/'
file_path_lensed = 'data/2d_imaging/lensed/'
# glob() returns files in arbitrary, filesystem-dependent order; sorting makes
# img_index select the same scene on every machine.
files = sorted(glob.glob(os.path.join(file_path_diffuser, '*.npy')))
if not files:
    raise FileNotFoundError('No .npy files found in {!r}'.format(file_path_diffuser))
if not -len(files) <= img_index < len(files):
    raise IndexError('img_index={} is out of range for {} files'.format(img_index, len(files)))

# The lensed (ground-truth) and diffuser (lensless) images share a file name.
img_name = os.path.basename(files[img_index])
lensed_img = helper.guass_fn(helper.load_img(file_path_lensed + img_name))

if simulation:
    print('In simulation {}% erasure'.format(erasure_rate*100))
    # Simulate the measurement by passing the ground truth through the forward
    # model; transpose(2,0,1) moves the last axis first (presumably HWC -> CHW).
    # NOTE(review): no erasure is applied explicitly in this branch -- confirm
    # that forward_zero_pad applies the mask the model was built with.
    img_meas = forward.forward_zero_pad(cu.np_to_ts(lensed_img.transpose(2,0,1)).type(dtype))
    img_meas /= torch.max(img_meas)
    img_meas_np = cu.ts_to_np(img_meas).transpose(1,2,0)
    img_meas_np /= np.max(img_meas_np)
else:
    print('In experiment {}% erasure'.format(erasure_rate*100))
    lensless_img = helper.load_img(file_path_diffuser + img_name)
    img_meas = cu.np_to_ts(lensless_img.transpose(2,0,1)).type(dtype)
    img_meas, img_meas_np = helper.apply_eraser(erasure_np, img_meas.cuda(), lensless_img)

# Show the inputs side by side: ground truth, max-normalized PSF, measurement.
fig_inputs, axes_inputs = plt.subplots(1, 3, figsize=(20, 10))
axes_inputs[0].set_title('Groundtruth')
axes_inputs[0].imshow(lensed_img)
axes_inputs[1].set_title('PSF')
axes_inputs[1].imshow(psf_np / np.max(psf_np))
axes_inputs[2].set_title('Measurement')
axes_inputs[2].imshow(img_meas_np)
plt.show()

Set up the hyperparameters and the network.

In [None]:
## Define network hyperparameters: 
input_depth = 80
INPUT =     'noise'
pad   =     'reflection'
LR = 1e-3
tv_weight = 1e-20
num_iter = 20000
reg_noise_std = 0.05

## initialize network input (noise)
net_input = cu.get_noise(input_depth, INPUT, (img_meas_np.shape[0]*2, img_meas_np.shape[1]*2)).type(dtype).detach()
net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()

## initialize netowrk
NET_TYPE = 'skip' # UNet, ResNet
net = md.get_net(input_depth, NET_TYPE, pad,skip_n33d=128,  skip_n33u=128,  skip_n11=4,  num_scales=5,upsample_mode='bilinear').type(dtype)

# Losses
mse = torch.nn.MSELoss().type(dtype)

p = [x for x in net.parameters()]
optimizer = torch.optim.Adam(p, lr=LR)

def main(n_iter=None, plot_every=500):
    """Optimize the network so its simulated measurement matches `img_meas`.

    Each iteration perturbs the fixed input `net_input_saved` with fresh
    noise scaled by `reg_noise_std`, and runs the network to get a
    reconstruction. It renders that reconstruction with `forward.forward`
    and L2-normalizes the result over dims 1-3. The loss is the MSE against
    `img_meas` plus `tv_weight` times the TV loss, minimized with the global
    `optimizer`.

    Args:
        n_iter: number of iterations. None (default) uses the global
            `num_iter`, read at call time.
        plot_every: call `helper.plot` every this many iterations, starting
            at iteration 0. A falsy value (0 or None) disables plotting.

    Returns:
        The reconstruction from the last iteration's forward pass (computed
        before that iteration's optimizer step). It is also kept in the
        global `recons` on every iteration, so a partial result is
        available if the cell is interrupted. If no iteration runs, the
        current global `recons` is returned.
    """
    global recons
    iterations = num_iter if n_iter is None else n_iter
    for i in range(iterations):
        optimizer.zero_grad()

        # Fresh input perturbation every step; `noise` is refilled in place.
        net_input = net_input_saved + noise.normal_() * reg_noise_std
        recons = net(net_input)
        # Unit-L2-norm rendering: the best fit is then the direction of
        # img_meas, independent of the measurement's overall scale.
        gen_meas = F.normalize(forward.forward(recons), p=2, dim=(1, 2, 3))
        loss = mse(gen_meas, img_meas) + tv_weight * df.tv_loss(recons)
        loss.backward()
        print('Iteration %05d, loss %.15f ' % (i, loss.item()), '\r', end='')

        if plot_every and i % plot_every == 0:
            helper.plot(lensed_img, recons)
        optimizer.step()
    return recons

## Run the reconstruction

In [None]:
# Run the full optimization. main() also updates the global `recons` on every
# iteration, so the latest reconstruction is kept even if this cell is interrupted.
recons = main()

In [None]:
# Plot the final reconstruction together with the lensed ground truth.
helper.plot(lensed_img, recons)