# Imports

In [None]:
# Standard library and scientific Python stack.
# NOTE(review): sys and os are not used in the cells shown — presumably left over from a removed
# "specify the path of your repo" (sys.path) cell; confirm before deleting.
import sys
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import os
# contexttimer is used to time the forward prediction.
import contexttimer
# Interactive matplotlib figures in the classic notebook interface.
%matplotlib notebook
# Reload edited modules (e.g. opticaltomography) automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# ArrayFire backend: select compute device 0 (presumably the first GPU — adjust for your machine).
import arrayfire as af
af.set_device(0)

## Import the `opticaltomography` package (make sure your repo is on the Python path)

In [None]:
from opticaltomography.opticsutil import compare3DStack, show3DStack
from opticaltomography.opticsalg import PhaseObject3D, TomographySolver, AlgorithmConfigs

# Specify parameters & load data

In [None]:
# Optical and sampling parameters. Units in microns.
wavelength = 0.514            # illumination wavelength
n_measure = 1.0               # refractive index handed to the solver as RI_measure
n_b = 1.0                     # background refractive index of the phantom
magnification = 80.
maginification = magnification  # original misspelled name, kept so existing references keep working
camera_pixel_size = 6.5       # presumably the camera pixel pitch in microns — confirm for your setup
dx = camera_pixel_size / magnification  # lateral voxel size along x
dy = camera_pixel_size / magnification  # lateral voxel size along y
dz = 3 * dx                   # axial slice spacing: three lateral voxels
na = 0.65                     # numerical aperture handed to the solver (presumably the objective's)

In [None]:
# Load the illumination NAs. The .mat file is read relative to the working directory — adjust the path if needed.
# For on-axis illumination only, replace both fx_illu_list and fy_illu_list below with [0.0].
ILLU_START_INDEX = 150  # rows before this index in the file are skipped
na_list = sio.loadmat("na_list_test.mat")
illu_na = na_list["na_list"][ILLU_START_INDEX:]  # presumably column 0 = NA_x, column 1 = NA_y — confirm
# Convert NA to spatial frequency (cycles per micron): f = NA / wavelength.
fx_illu_list = illu_na[:, 0] / wavelength
fy_illu_list = illu_na[:, 1] / wavelength

# Create a homogeneous phantom and plot its z-slices (array axes: y, x, z)

In [None]:
# Homogeneous phantom filled with the background index n_b; array axes are (y, x, z).
phantom_shape = (400, 400, 10)
phantom = np.ones(phantom_shape, dtype="complex64") * n_b
background_slices = np.real(phantom)
show3DStack(background_slices, axis=2, clim=(np.min(background_slices), np.max(background_slices)))

## Fill in the phantom: add a disk of higher refractive index to slice 4

In [None]:
# Add a disk of raised refractive index to one z-slice of the phantom.
DISK_SLICE = 4       # z index of the slice that receives the disk
DISK_RADIUS = 0.25   # disk radius in the normalized [-1, 1] coordinates below
DISK_PHASE = 0.1     # target value of 2*pi*delta_n*dz/wavelength (presumably the phase delay in radians through one slice)
# meshgrid's default "xy" indexing returns arrays of shape (len(second), len(first)), so pass the
# axis-1 (x) samples first and the axis-0 (y) samples second to match phantom's (y, x) layout.
# The original passed them the other way round, which only worked because the phantom is square.
x, y = np.meshgrid(np.linspace(-1, 1, phantom.shape[1]), np.linspace(-1, 1, phantom.shape[0]))
r2 = x ** 2 + y ** 2
# Assign instead of "+=" so re-running this cell does not add the contrast a second time.
# Assumes the slice starts as the uniform background n_b, as set when the phantom is created.
phantom[..., DISK_SLICE] = n_b + (r2 < DISK_RADIUS ** 2) * DISK_PHASE / (2 * np.pi * dz / wavelength)
phantom_real = np.real(phantom)
show3DStack(phantom_real, axis=2, clim=(np.min(phantom_real), np.max(phantom_real)))

# Set up solver objects

In [None]:
# Solver settings; the same solver object is used for the forward simulation and the reconstruction.
solver_params = dict(
    wavelength=wavelength,
    na=na,
    RI_measure=n_measure,
    # equals k0*dz (2*pi*dz/wavelength); how TomographySolver uses it is not visible here
    sigma=2 * np.pi * dz / wavelength,
    fx_illu_list=fx_illu_list,
    fy_illu_list=fy_illu_list,
    pad=True,
    pad_size=(25, 25),
)
# voxel_size follows the phantom's (y, x, z) axis order.
phase_obj_3d = PhaseObject3D(shape=phantom.shape, voxel_size=(dy, dx, dz), RI=n_b, RI_obj=phantom)
solver_obj = TomographySolver(phase_obj_3d, **solver_params)
# Forward simulation model; "MultiPhaseContrast" is the alternative model offered in this notebook.
solver_obj.setScatteringMethod(model="MultiBorn")

# Generate forward prediction

In [None]:
# Time the forward prediction; the elapsed time is printed as soon as forwardPredict returns.
with contexttimer.Timer() as timer:
    forward_prediction = solver_obj.forwardPredict(field=False)
    print(timer.elapsed)
# Drop singleton axes so the result can be shown with show3DStack and passed to solve().
forward_field_mb = np.squeeze(forward_prediction)

In [None]:
#plot
%matplotlib notebook
show3DStack(np.real(forward_field_mb), axis=2, clim=(np.min(np.real(forward_field_mb)), np.max(np.real(forward_field_mb))))

# Solving an inverse problem

In [None]:
#Create a class for all inverse problem parameters
configs            = AlgorithmConfigs()
configs.batch_size = 1
configs.method     = "FISTA"
configs.restart    = True
configs.max_iter   = 5
# multislice stepsize
# configs.stepsize   = 2e-4
# multiborn stepsize
configs.stepsize   = 10
configs.error      = []
configs.pure_real = True
#total variation regularization
configs.total_variation     = False
configs.reg_tv              = 1.0 #lambda
configs.max_iter_tv         = 15
configs.order_tv            = 1
configs.total_variation_gpu = True
configs.total_variation_anisotropic = False

# reconstruction method
# solver_obj.setScatteringMethod(model = "MultiPhaseContrast")
solver_obj.setScatteringMethod(model = "MultiBorn")

In [None]:
# Run the iterative reconstruction on the simulated measurements forward_field_mb using the settings in configs.
# NOTE(review): presumably uses the scattering model most recently set on solver_obj — confirm in TomographySolver.
recon_obj_3d = solver_obj.solve(configs, forward_field_mb)

## Plotting results

In [None]:
# Keep the recorded cost for the convergence plot and show the reconstruction slice by slice along z.
cost = solver_obj.configs.error
current_rec = recon_obj_3d
current_rec_real = np.real(current_rec)
show3DStack(current_rec_real, axis=2, clim=(np.min(current_rec_real), np.max(current_rec_real)))

In [None]:
# Convergence plot: log10 of the cost values recorded during solve(), in recording order.
# Zero cost entries would plot as -inf (numpy emits a RuntimeWarning).
fig, ax = plt.subplots()
ax.plot(np.log10(cost))
ax.set_xlabel("recorded step")
ax.set_ylabel("log10(cost)")
ax.set_title("Reconstruction cost");