Choose a sample to run by setting `sample_id` (in the next cell) to one of:
- mouse_vas_deferens1
- mouse_vas_deferens2
- mouse_femoral_artery
- mouse_bladder
- mouse_trachea
- human_cornea
- insect_leg

In [1]:
# Samples for which OCRT input data is available (see the list above).
VALID_SAMPLE_IDS = ('mouse_vas_deferens1', 'mouse_vas_deferens2', 'mouse_femoral_artery',
                    'mouse_bladder', 'mouse_trachea', 'human_cornea', 'insect_leg')

sample_id = 'mouse_bladder'  # a string that denotes which data to run
data_directory = 'data/'  # where is the OCRT input data located?
save_directory = 'saved_models_and_variables/'  # where to save the tf graph after optimization

# Fail fast here, before an OCRT2D object is constructed with an unknown sample_id.
if sample_id not in VALID_SAMPLE_IDS:
    raise ValueError('invalid sample_id: ' + sample_id)

# Registration of multi-angle B-scans and synthesis of a refractive index map
The registration metric is optimized with respect to the parameters of the deformation model, which here is ray propagation through a spatially inhomogeneous refractive index map.

In [2]:
from __future__ import print_function, division
from OCRT import OCRT2D
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from time import time
%matplotlib inline

a = OCRT2D(sample_id=sample_id,save_directory=save_directory)

In [3]:
# adjust some parameters, depending on the sample:
# number of registration iterations per sample
ITERATIONS_BY_SAMPLE = {
    'mouse_bladder': 200,
    'mouse_femoral_artery': 200,
    'mouse_trachea': 200,
    'mouse_vas_deferens1': 200,
    'mouse_vas_deferens2': 300,
    'insect_leg': 500,
    'human_cornea': 500,
}
if sample_id not in ITERATIONS_BY_SAMPLE:
    raise ValueError('invalid sample_id: ' + sample_id)
num_iter = ITERATIONS_BY_SAMPLE[sample_id]

if sample_id == 'human_cornea':
    # this sample is more difficult to register, so use a multiresolution approach:
    a.use_multires = True
    a.size_factor_ = 1  # the final size_factor_ is 8
    a.switch_iter = 250

# two tube sizes were used (same inner diameter, but different outer diameter):
if sample_id in ['mouse_vas_deferens1', 'mouse_vas_deferens2', 'mouse_trachea']:
    a.tube_diameter = 1.066  # in mm
else:
    a.tube_diameter = 1.108516  # in mm

In [4]:
# Load the OCRT input data from data_directory, then build the TensorFlow graph.
# NOTE(review): on this TF 1.x install, build_graph has failed with
# "module 'tensorflow._api.v1.config' has no attribute 'list_physical_devices'"
# (see the captured output below). That function is only available directly under
# tf.config from TF 2.1 on.
a.load_data_and_resolve_constants(data_directory=data_directory)
a.build_graph()

data loaded: 10.488995552062988 sec


AttributeError: module 'tensorflow._api.v1.config' has no attribute 'list_physical_devices'

In [None]:
# run through optimization loop:
losses = list()
feed_dict = a.get_feed_dict()

# Legend labels for the loss plot, fetched once through the object's own session.
# Tensor.eval() only works when a default session is registered, which this notebook
# never does explicitly. TF string tensors come back as bytes under Python 3, so
# decode them for readable legend entries.
loss_term_names = [name.decode() if isinstance(name, bytes) else name
                   for name in a.sess.run(a.loss_term_names)]

# Zero spatial shifts, fed when displaying the refractive index map so the xz_delta
# shifts are removed. Sized from the tensor itself instead of a hard-coded (60, 2).
# NOTE(review): assumes a.xz_delta has a fully defined static shape.
zero_xz_delta = np.zeros(a.xz_delta.get_shape().as_list())

for i in range(num_iter + 1):

    # if using multires, change the pixel resolution of the reconstruction
    # (use_multires is checked first so switch_iter is only read when relevant)
    if a.use_multires and i == a.switch_iter:
        feed_dict[a.size_factor] = a.final_size_factor

    start = time()
    loss_i, _ = a.sess.run([a.loss_terms, a.train_op], feed_dict=feed_dict)
    losses.append(loss_i)
    print(i, loss_i, time()-start)
    # loss_i is a list of all the contributors to the scalar loss;
    # (see a.loss_terms or a.loss_term_names for names of the regularization terms)

    # monitor results periodically:
    if i % 10 == 0:
        recon_i = a.sess.run(a.recon, feed_dict=feed_dict)
        recon_i = recon_i.sum(2)  # only one slice along y contains nonzero values, because we are optimizing 2D
        plt.figure(figsize=(10, 10))
        plt.imshow(recon_i, cmap='gray_r')
        plt.title('OCRT reconstruction')
        plt.show()

        RI = a.sess.run(a.RI, feed_dict={a.xz_delta: zero_xz_delta})  # remove xz_delta shifts

        plt.imshow(RI)
        plt.title('refractive index map')
        plt.colorbar()
        plt.show()

        plt.plot(losses)
        plt.legend(loss_term_names)
        plt.title('loss terms')
        plt.show()

InternalError: 2 root error(s) found.
  (0) Internal: Blas xGEMMBatched launch failed : a.shape=[60,36,2], b.shape=[60,2,2], m=36, n=2, k=2, batch_size=60
	 [[node ray_propagation/scan/while/MatMul (defined at c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\framework\ops.py:1748) ]]
	 [[backprojection/Unique/_63]]
  (1) Internal: Blas xGEMMBatched launch failed : a.shape=[60,36,2], b.shape=[60,2,2], m=36, n=2, k=2, batch_size=60
	 [[node ray_propagation/scan/while/MatMul (defined at c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\framework\ops.py:1748) ]]
0 successful operations.
0 derived errors ignored.

Original stack trace for 'ray_propagation/scan/while/MatMul':
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\traitlets\config\application.py", line 992, in launch_instance
    app.start()
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\ipykernel\kernelapp.py", line 712, in start
    self.io_loop.start()
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tornado\platform\asyncio.py", line 215, in start
    self.asyncio_loop.run_forever()
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\asyncio\base_events.py", line 541, in run_forever
    self._run_once()
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\asyncio\base_events.py", line 1786, in _run_once
    handle._run()
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\asyncio\events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\ipykernel\kernelbase.py", line 510, in dispatch_queue
    await self.process_one()
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\ipykernel\kernelbase.py", line 499, in process_one
    await dispatch(*args)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\ipykernel\kernelbase.py", line 406, in dispatch_shell
    await result
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\ipykernel\kernelbase.py", line 730, in execute_request
    reply_content = await reply_content
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\ipykernel\ipkernel.py", line 390, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\ipykernel\zmqshell.py", line 528, in run_cell
    return super().run_cell(*args, **kwargs)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\IPython\core\interactiveshell.py", line 2915, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\IPython\core\interactiveshell.py", line 2960, in _run_cell
    return runner(coro)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\IPython\core\async_helpers.py", line 78, in _pseudo_sync_runner
    coro.send(None)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\IPython\core\interactiveshell.py", line 3186, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\IPython\core\interactiveshell.py", line 3377, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\IPython\core\interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\alexa\AppData\Local\Temp\ipykernel_9924\4203457682.py", line 2, in <module>
    a.build_graph()
  File "c:\Users\alexa\Desktop\University\LINUM\optical-coherence-refraction-tomography\OCRT.py", line 299, in build_graph
    self.create_losses()
  File "c:\Users\alexa\Desktop\University\LINUM\optical-coherence-refraction-tomography\OCRT.py", line 425, in create_losses
    return_derivs=True)
  File "c:\Users\alexa\Desktop\University\LINUM\optical-coherence-refraction-tomography\OCRT.py", line 714, in integrate_scan_opl
    paths = tf.scan(self.rk4_step_opl, dummy, xz_init0, swap_memory=True)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\ops\functional_ops.py", line 508, in scan
    maximum_iterations=n)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\ops\control_flow_ops.py", line 2753, in while_loop
    return_same_structure)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\ops\control_flow_ops.py", line 2245, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\ops\control_flow_ops.py", line 2170, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\ops\control_flow_ops.py", line 2705, in <lambda>
    body = lambda i, lv: (i + 1, orig_body(*lv))
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\ops\functional_ops.py", line 485, in compute
    a_out = fn(packed_a, packed_elems)
  File "c:\Users\alexa\Desktop\University\LINUM\optical-coherence-refraction-tomography\OCRT.py", line 682, in rk4_step_opl
    (deriv, n) = self.rayeq_opl(z0, x0)
  File "c:\Users\alexa\Desktop\University\LINUM\optical-coherence-refraction-tomography\OCRT.py", line 668, in rayeq_opl
    (n, dndx, dndz) = self.indexdist(x[:, :, 0], z)
  File "c:\Users\alexa\Desktop\University\LINUM\optical-coherence-refraction-tomography\OCRT.py", line 610, in indexdist
    xzp = tf.matmul(xz, tf.transpose(self.rotmats, [0, 2, 1]))
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\util\dispatch.py", line 180, in wrapper
    return target(*args, **kwargs)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\ops\math_ops.py", line 2716, in matmul
    return batch_mat_mul_fn(a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\ops\gen_math_ops.py", line 1712, in batch_mat_mul_v2
    "BatchMatMulV2", x=x, y=y, adj_x=adj_x, adj_y=adj_y, name=name)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\framework\op_def_library.py", line 794, in _apply_op_helper
    op_def=op_def)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\framework\ops.py", line 3357, in create_op
    attrs, op_def, compute_device)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\framework\ops.py", line 3426, in _create_op_internal
    op_def=op_def)
  File "c:\Users\alexa\anaconda3\envs\env3.7\lib\site-packages\tensorflow_core\python\framework\ops.py", line 1748, in __init__
    self._traceback = tf_stack.extract_stack()


In [None]:
# Persist the optimized registration graph/variables to save_directory.
a.save_graph()  # this graph must be saved in order to run filter optimization below

# Filter optimization after registration
Freeze the registration/refractive index parameters and optimize only the 2D backprojection filter. This stage requires the graph saved at the end of the registration stage.

In [None]:
# remove previous tf graph and variables:
a.sess.close()
tf.reset_default_graph()
del a

In [None]:
# instantiate new object, this time for filter optimization
a = OCRT2D(sample_id=sample_id,save_directory=save_directory)
num_iter = 100
a.infer_backprojection_filter = True
a.use_spatial_shifts = False

# as above, set the tube diameter depending on the sample:
if sample_id in ['mouse_vas_deferens1', 'mouse_vas_deferens2', 'mouse_trachea']:
    a.tube_diameter = 1.066
else:
    a.tube_diameter = 1.108516

a.load_data_and_resolve_constants(data_directory=data_directory)
a.build_graph()

In [None]:
# run the filter-only optimization loop:
losses = list()
feed_dict = a.get_feed_dict()

# Legend labels for the loss plot, fetched once through the object's own session
# (Tensor.eval() needs a registered default session). Decode the bytes that TF string
# tensors return under Python 3.
loss_term_names = [name.decode() if isinstance(name, bytes) else name
                   for name in a.sess.run(a.loss_term_names)]

for i in range(num_iter + 1):
    start = time()
    loss_i, _ = a.sess.run([a.loss_terms, a.train_op], feed_dict=feed_dict)
    losses.append(loss_i)
    print(i, loss_i, time()-start)

    # monitor results periodically:
    if i % 20 == 0:
        recon_i = a.sess.run(a.recon, feed_dict=feed_dict)
        recon_i = recon_i.sum(2)  # collapse the y axis, as in the registration loop
        plt.figure(figsize=(10, 10))
        plt.imshow(recon_i, cmap='gray_r')
        plt.title('OCRT reconstruction')
        plt.show()

        plt.plot(losses)
        plt.legend(loss_term_names)
        plt.title('loss terms')
        plt.show()

In [None]:
# Optional: the filter-optimization graph is not used later in this notebook.
# Uncomment the next line to save it anyway.
# a.save_graph()  # this graph doesn't need to be saved