In [None]:
import array_api_compat.torch as xp
import torch
import parallelproj
from array_api_compat import to_device
import array_api_compat.numpy as np
import matplotlib.pyplot as plt
import time

# Load the dictionary stored inside the .npy file (.item() unwraps the 0-d
# object array returned by np.load).
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects -- only
# load this file if it comes from a trusted source.
dict_data = np.load("large_regular_polygon_geom.npy", allow_pickle=True).item()


def _to_gpu(key):
    """Return ``dict_data[key]`` converted with ``xp.asarray`` and moved to CUDA."""
    return xp.asarray(dict_data[key]).cuda()


xstart = _to_gpu('xstart')
xend = _to_gpu('xend')
img_origin = _to_gpu('img_origin')
voxel_size = _to_gpu('voxel_size')
# img_shape is used exactly as stored in the file (it is not moved to the GPU).
img_shape = dict_data['img_shape']
x_gt = _to_gpu('x')
meas = _to_gpu('meas')


def fwd(x):
    """Apply ``parallelproj.joseph3d_fwd`` to the image ``x`` with the loaded geometry."""
    return parallelproj.joseph3d_fwd(xstart, xend, x, img_origin, voxel_size)


def back(y):
    """Apply ``parallelproj.joseph3d_back`` to ``y`` with the loaded geometry and ``img_shape``."""
    return parallelproj.joseph3d_back(xstart, xend, img_shape, img_origin, voxel_size, y)

# Benchmark: total time of 10 forward projections of x_gt, then of 10 back
# projections of meas.
#
# time.perf_counter() is used instead of time.time() because it is monotonic
# and has the highest available resolution. The device is synchronized before
# each timer starts so that GPU work queued earlier is not billed to the
# measurement, and after every call so that each call has completed before
# the timer is read.
print("ParallelProj Forward...")
torch.cuda.synchronize(x_gt.device)
start = time.perf_counter()
for i in range(10):
    tmp = fwd(x_gt)
    torch.cuda.synchronize(x_gt.device)
end_t = time.perf_counter()
print(f"ParallelProj Time: {(end_t - start):.6f} s")

print("ParallelProj Backward...")
torch.cuda.synchronize(x_gt.device)
start = time.perf_counter()
for i in range(10):
    tmp = back(meas)
    torch.cuda.synchronize(x_gt.device)
end_t = time.perf_counter()
print(f"ParallelProj Time: {(end_t - start):.6f} s")

# Sensitivity image: back projection of an all-ones array shaped like meas.
sens_img = back(xp.ones_like(meas))

# The iteration starts from an image of ones with the shape of sens_img.
x = xp.ones_like(sens_img)

# Finish any GPU work queued so far so it is not billed to the timed loop.
torch.cuda.synchronize(x_gt.device)
start = time.perf_counter()
for i in range(100):
    # NOTE(review): voxels where sens_img == 0 would produce inf/NaN in this
    # division -- confirm the geometry has no zero-sensitivity voxels.
    em_precond = x/sens_img
    # NOTE(review): the constant 100. is added to the forward projection,
    # presumably a known additive background term -- confirm against how
    # `meas` was generated.
    ratio = meas/(fwd(x)+100.)
    torch.cuda.synchronize(x_gt.device)
    bp_ratio = back(ratio)
    torch.cuda.synchronize(x_gt.device)
    x = em_precond*bp_ratio
# The last element-wise update may still be running on the GPU; wait for it
# so that the measured time covers all 100 complete iterations.
torch.cuda.synchronize(x_gt.device)
end_t = time.perf_counter()
print(f"ParallelProj 100 MLEM: {end_t - start:.6f} s")


import matplotlib.pyplot as plt
from ipywidgets import interact
def show_slices(slice_idx):
    """Show the plane ``x[:, slice_idx, :]`` as a gray-scale image with a colorbar.

    Reads the module-level array ``x``; the selected plane is moved to the
    host with ``.cpu()`` before it is handed to matplotlib.
    """
    plane = x[:, slice_idx, :].cpu()
    plt.imshow(plane, cmap='gray')
    plt.colorbar()
    plt.show()

# Write the reconstruction to disk, then build a slider that steps through
# every index along axis 1 of x and redraws the plane via show_slices.
xp.save(x, "parallel_proj_recon.torch")
last_slice_idx = x.shape[1] - 1
interact(show_slices, slice_idx=(0, last_slice_idx, 1))

In [None]:
# Adjointness test: if back is the exact adjoint of fwd, the inner products
# <fwd(x_), y> and <x_, back(y)> agree up to floating point round-off.

x_ = xp.ones_like(x_gt)/100
y = xp.ones_like(meas)/100

# Both inner products are accumulated in float64 (.double()) to limit
# summation round-off.
inner_1 = xp.sum(fwd(x_).ravel().double()*y.ravel().double())
inner_2 = xp.sum(back(y).ravel().double()*x_.ravel().double())
print(inner_1)
print(inner_2)
mean_inner = (inner_1+inner_2)/2

# Absolute difference relative to the mean of both inner products, expressed
# in percent (hence the factor 100 and the unit in the message).
rel_diff_percent = xp.abs(inner_1-inner_2)/mean_inner*100
print(f"The relative difference is {rel_diff_percent:.3e} %")