In [None]:
# The star import stays first so that the explicit imports below take
# precedence over any same-named symbols it brings in.
from triton_joseph_proj import *
import time

import numpy as np
import matplotlib.pyplot as plt
import torch
import array_api_compat.torch as xp

# NOTE(review): allow_pickle=True unpickles arbitrary Python objects and can
# execute code while loading -- only use this with a trusted geometry file.
dict_data = np.load("large_regular_polygon_geom.npy", allow_pickle=True).item()


def _to_gpu(arr):
    """Convert ``arr`` to a float32 tensor on the current CUDA device."""
    return xp.asarray(arr).cuda().float()


xstart = _to_gpu(dict_data['xstart'])
xend = _to_gpu(dict_data['xend'])
img_origin = _to_gpu(dict_data['img_origin'])
voxel_size = _to_gpu(dict_data['voxel_size'])
img_shape = dict_data['img_shape']
x_gt = _to_gpu(dict_data['x'])
meas = _to_gpu(dict_data['meas'])


def fwd(x):
    """Apply ``joseph3d_fwd_vec`` to the image ``x``.

    ``xstart`` / ``xend`` are flattened to shape (-1, 3) for the call and the
    result is reshaped to ``meas.shape``.
    """
    return joseph3d_fwd_vec(
        xstart.reshape(-1, 3), xend.reshape(-1, 3), x, img_origin, voxel_size, img_shape
    ).reshape(meas.shape)


def back(y):
    """Apply ``joseph3d_back_vec`` to the data ``y``.

    ``y`` is flattened with ``ravel()`` and ``xstart`` / ``xend`` are flattened
    to shape (-1, 3) before the call; the kernel's result is returned as is.
    """
    return joseph3d_back_vec(
        xstart.reshape(-1, 3), xend.reshape(-1, 3), y.ravel(), img_origin, voxel_size, img_shape
    )

class Joseph3D(torch.autograd.Function):
    """Autograd wrapper around the module-level ``fwd`` operator.

    The forward pass evaluates ``fwd``; the backward pass maps the incoming
    gradient through ``back``. Nothing is saved on ``ctx``.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``fwd(x)``."""
        projected = fwd(x)
        return projected

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``back(grad_output)`` as the gradient with respect to ``x``."""
        grad_input = back(grad_output)
        return grad_input

class Joseph3DAdjoint(torch.autograd.Function):
    """Autograd wrapper around the module-level ``back`` operator.

    The forward pass evaluates ``back``; the backward pass maps the incoming
    gradient through ``fwd``. Nothing is saved on ``ctx``.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``back(x)``."""
        back_projected = back(x)
        return back_projected

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``fwd(grad_output)`` as the gradient with respect to ``x``."""
        grad_input = fwd(grad_output)
        return grad_input


# Run each operator once before timing.
# (Presumably this keeps one-time Triton compilation cost out of the
# measurements below -- TODO confirm.)
tmp_meas = fwd(x_gt)
tmp_bp = back(meas)


# CUDA kernels are launched asynchronously: the Python call returns before the
# GPU has finished. torch.cuda.synchronize() is therefore required before every
# clock read, otherwise mostly launch overhead is measured.
print("Triton Forward...")
torch.cuda.synchronize()
start = time.perf_counter()
for i in range(10):
    tmp = fwd(x_gt)
torch.cuda.synchronize()
end_t = time.perf_counter()
# Total wall time of the 10 calls (not the per-call average).
print(f"Triton Time: {(end_t - start):.6f} s")

print("Triton Backward...")
torch.cuda.synchronize()
start = time.perf_counter()
for i in range(10):
    tmp = back(meas)
torch.cuda.synchronize()
end_t = time.perf_counter()
print(f"Triton Time: {(end_t - start):.6f} s")

# Sensitivity image: back projection of an all-ones data array.
sens_img = back(xp.ones_like(meas))

# The MLEM iteration starts from an all-ones image.
x = xp.ones_like(sens_img)

torch.cuda.synchronize()
start = time.perf_counter()
for i in range(100):
    # NOTE(review): divides by zero wherever sens_img == 0 -- confirm whether
    # such voxels exist for this geometry and need masking.
    em_precond = x/sens_img
    # 100. is a constant added to the expected data before taking the ratio
    # (presumably a known additive background term -- confirm).
    ratio = meas/(fwd(x)+100.)
    bp_ratio = back(ratio)
    x = em_precond*bp_ratio
torch.cuda.synchronize()
end_t = time.perf_counter()
print(f"Triton 100 MLEM: {end_t - start:.6f} s")

import matplotlib.pyplot as plt
from ipywidgets import interact
def show_slices(slice_idx):
    """Show ``x[:, slice_idx, :]`` as a grayscale image with a colorbar."""
    plane = x[:, slice_idx, :].cpu()  # matplotlib needs host memory
    plt.imshow(plane, cmap='gray')
    plt.colorbar()
    plt.show()


# Integer slider over every valid index along axis 1 of ``x``.
slice_range = (0, x.shape[1] - 1, 1)
interact(show_slices, slice_idx=slice_range)

In [None]:
# Report the data and image dimensions. Dropping the last axis of ``xstart``
# (indexed with 0) leaves one entry per start point.
first_component = xstart[..., 0]
print("Number of LORs: ", first_component.ravel().shape)
print("Sinogram dims: ", first_component.shape)
print("The shape of the image: ", x.shape)

In [None]:
# adjoint-ness test
#
# For a matched operator pair, <fwd(u), v> == <u, back(v)> must hold for
# arbitrary u and v. All-ones vectors are a weak probe: both inner products
# then reduce to the sum of all operator weights, which can agree even when
# the pair is not adjoint. Seeded random vectors are used instead so that the
# check is both meaningful and reproducible.
_gen = torch.Generator(device=x_gt.device).manual_seed(0)
_x = torch.rand(x_gt.shape, generator=_gen, device=x_gt.device, dtype=x_gt.dtype)
y = torch.rand(meas.shape, generator=_gen, device=meas.device, dtype=meas.dtype)

# Accumulate in float64 so summation round-off does not mask a real mismatch.
inner_1 = xp.sum(fwd(_x).ravel().double()*y.ravel().double())
inner_2 = xp.sum(back(y).ravel().double()*_x.ravel().double())
print(inner_1)
print(inner_2)
mean_inner = (inner_1+inner_2)/2
# The printed value is the relative difference in percent (factor of 100).
print(f"The difference is {xp.abs(inner_1-inner_2)/mean_inner*100:.3e}")

In [None]:
# Tensor loaded from disk, displayed next to ``x`` below.
x_parallel_proj = torch.load("parallel_proj_recon.torch")

import matplotlib.pyplot as plt
from ipywidgets import interact
def show_slices(slice_idx):
    """Plot ``x``, ``x_parallel_proj`` and their squared difference side by side.

    All three panels show the plane ``[:, slice_idx, :]`` in grayscale, each
    with its own colorbar.
    """
    plt.figure(figsize=(18, 6), dpi=80)

    current = x[:, slice_idx, :].cpu()
    loaded = x_parallel_proj[:, slice_idx, :].cpu()
    panels = (current, loaded, (current - loaded).pow(2))

    for position, panel in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(panel, cmap='gray')
        plt.colorbar()
    plt.show()

# Integer slider over every valid index along axis 1 of ``x``.
interact(show_slices, slice_idx=(0, x.shape[1] - 1, 1))