In [None]:
from triton_joseph_proj import *
import numpy as np
import matplotlib.pyplot as plt
import time
import array_api_compat.torch as xp

# NOTE(review): absolute, machine-specific path; consider a configurable/relative location.
DATA_PATH = "/home/user/triton/slices_8_lors_40.npy"

# The .npy file holds a pickled Python dict (unwrapped with .item()), hence allow_pickle=True.
# Unpickling can execute arbitrary code, so only load files from a trusted source.
dict_data = np.load(DATA_PATH, allow_pickle=True).item()

xstart = xp.asarray(dict_data['xstart']).cuda()
xend = xp.asarray(dict_data['xend']).cuda()
img_origin = xp.asarray(dict_data['img_origin']).cuda()
voxel_size = xp.asarray(dict_data['voxel_size']).cuda()
img_shape = dict_data['img_shape']
x_gt = xp.asarray(dict_data['x']).cuda()
meas_raw = xp.asarray(dict_data['meas']).cuda()
# NOTE(review): constant offset of 100 with no stated meaning. In this notebook `meas`
# only supplies a shape (to `fwd` and to the adjointness test), so the offset does not
# change any result; the back-projection below uses the un-offset `meas_raw`.
meas = meas_raw + 100

# LOR end points as flat (N, 3) arrays, reshaped once instead of on every projector call.
xstart_flat = xstart.reshape(-1, 3)
xend_flat = xend.reshape(-1, 3)


def fwd(x):
    """Forward-project image `x` with the Triton Joseph projector; output reshaped to `meas.shape`."""
    return joseph3d_fwd_vec(xstart_flat, xend_flat, x, img_origin, voxel_size, img_shape).reshape(meas.shape)


def back(y):
    """Back-project data `y` (flattened first) into image space with the Triton Joseph projector."""
    return joseph3d_back_vec(xstart_flat, xend_flat, y.ravel(), img_origin, voxel_size, img_shape)


# Warm-up: run each projector once and wait for completion (.cpu()) before the real run.
# NOTE(review): presumably this absorbs one-off Triton JIT compilation / autotuning — confirm.
fwd(x_gt).cpu()
back(meas_raw).cpu()

# Compute each projection once and save it to disk; both results are detached so
# neither saved tensor carries autograd history.
triton_fwd_result = fwd(x_gt).cpu().detach()
triton_back_result = back(meas_raw).cpu().detach()
xp.save(triton_fwd_result, "triton_joseph_proj_forward.torch")
xp.save(triton_back_result, "triton_joseph_proj_backward.torch")

In [None]:
import torch

# Dot-product (adjointness) test: for a matched projector pair A / A^T,
# <A x, y> must equal <x, A^T y> for every x and y.
# Seeded random inputs are used instead of all-ones vectors. All-ones vectors only
# check that the total sums of A and A^T agree, which can hide a mismatched adjoint.
# NOTE(review): assumes x_gt is a floating-point tensor (it is fed to `fwd`), so its
# dtype is also used for the random data-space vector.
rng = torch.Generator(device=x_gt.device).manual_seed(0)
x_test = torch.rand(x_gt.shape, generator=rng, dtype=x_gt.dtype, device=x_gt.device)
y_test = torch.rand(meas.shape, generator=rng, dtype=x_gt.dtype, device=x_gt.device)

# Accumulate the inner products in float64 so float32 summation error over the
# large arrays does not dominate the comparison.
inner_fwd = (fwd(x_test).double() * y_test.double()).sum()
inner_back = (back(y_test).double() * x_test.double()).sum()
rel_diff = ((inner_fwd - inner_back).abs() / inner_fwd.abs()).item()
print(inner_fwd.item(), inner_back.item(), rel_diff)

In [None]:
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact

parallel_fwd = torch.load("parallel_proj_forward.torch")
triton_fwd = torch.load("triton_joseph_proj_forward.torch")
print(triton_fwd.shape)


def make_fwd_slice_viewer(triton_sino, parallel_sino):
    """Return a slider callback comparing two forward projections slice-by-slice along axis 2.

    The tensors are captured in a closure rather than read from globals. The widget
    then keeps showing these forward projections even if a later cell rebinds global names.
    """
    def show_slices(slice_idx):
        triton_slice = triton_sino[:, :, slice_idx].cpu()
        parallel_slice = parallel_sino[:, :, slice_idx].cpu()
        fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=80)
        panels = (
            ("Triton", triton_slice),
            ("Parallel Proj", parallel_slice),
            ("Triton - Parallel Proj", triton_slice - parallel_slice),
        )
        for ax, (title, image) in zip(axes, panels):
            im = ax.imshow(image, cmap='gray')
            ax.set_title(title)
            fig.colorbar(im, ax=ax)
        plt.show()

    return show_slices


interact(make_fwd_slice_viewer(triton_fwd, parallel_fwd), slice_idx=(0, triton_fwd.shape[2] - 1, 1));

In [None]:
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact

parallel_back = torch.load("parallel_proj_backward.torch")
triton_back = torch.load("triton_joseph_proj_backward.torch")


def make_back_slice_viewer(triton_img, parallel_img):
    """Return a slider callback comparing two back-projections slice-by-slice along axis 1.

    The tensors are captured in a closure rather than read from globals. The widget
    then keeps showing these back-projections even if another cell rebinds global names.
    """
    def show_slices(slice_idx):
        triton_slice = triton_img[:, slice_idx, :].cpu()
        parallel_slice = parallel_img[:, slice_idx, :].cpu()
        fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=80)
        panels = (
            ("Triton", triton_slice),
            ("Parallel Proj", parallel_slice),
            ("Triton - Parallel Proj", triton_slice - parallel_slice),
        )
        for ax, (title, image) in zip(axes, panels):
            im = ax.imshow(image, cmap='gray')
            ax.set_title(title)
            fig.colorbar(im, ax=ax)
        plt.show()

    return show_slices


interact(make_back_slice_viewer(triton_back, parallel_back), slice_idx=(0, triton_back.shape[1] - 1, 1));

In [None]:
import torch
import matplotlib.pyplot as plt

parallel_proj_forward = torch.load("parallel_proj_forward.torch")
triton_joseph_proj_forward = torch.load("triton_joseph_proj_forward.torch")
parallel_proj_backward = torch.load("parallel_proj_backward.torch")
triton_joseph_proj_backward = torch.load("triton_joseph_proj_backward.torch")
forward_difference = (parallel_proj_forward - triton_joseph_proj_forward).abs()
backward_difference = (parallel_proj_backward - triton_joseph_proj_backward).abs()

sino_slice = 0  # index along the last axis of the forward projections (top row)
img_slice = 0   # index along axis 1 of the back-projections (bottom row)
inside = 120    # rows trimmed from each end of the forward projections' first axis; 0 = no trimming

# Build the crop explicitly: `t[inside:-inside]` would be EMPTY for inside == 0 (since -0 == 0).
n_sino_rows = parallel_proj_forward.shape[0]
if not 0 <= 2 * inside < n_sino_rows:
    raise ValueError(f"inside={inside} leaves no rows of {n_sino_rows}")
sino_rows = slice(inside, n_sino_rows - inside)


def _sino_view(sino):
    """Cropped 2-D forward-projection slice at `sino_slice`, moved to CPU for plotting."""
    return sino[sino_rows, ..., sino_slice].cpu()


def _img_view(img):
    """2-D back-projection slice at `img_slice` along axis 1, moved to CPU for plotting."""
    return img[:, img_slice].cpu()


def _draw_panel(fig, ax, image, title=None, ylabel=None):
    """Draw `image` in grayscale with a colorbar on `ax`.

    Left-column panels (given a `ylabel`) keep the axis so the row label stays visible,
    but hide ticks and the frame. All other panels turn the axis off entirely.
    """
    im = ax.imshow(image, cmap="gray")
    if title is not None:
        ax.set_title(title)
    if ylabel is None:
        ax.axis("off")
    else:
        ax.set_ylabel(ylabel)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)
    fig.colorbar(im, ax=ax)


print(_sino_view(parallel_proj_forward).shape)

fig, axes = plt.subplots(2, 3, figsize=(9, 5))
# Top row: forward projections.
_draw_panel(fig, axes[0, 0], _sino_view(parallel_proj_forward), title="Parallel Proj", ylabel="Forward-Projection")
_draw_panel(fig, axes[0, 1], _sino_view(triton_joseph_proj_forward), title="Triton")
_draw_panel(fig, axes[0, 2], _sino_view(forward_difference), title="Absolute Difference")
# Bottom row: back-projections.
_draw_panel(fig, axes[1, 0], _img_view(parallel_proj_backward), ylabel="Back-Projection")
_draw_panel(fig, axes[1, 1], _img_view(triton_joseph_proj_backward))
_draw_panel(fig, axes[1, 2], _img_view(backward_difference))
fig.tight_layout()
fig.savefig("validation.png")
