# In [None]:
# Star import stays first so the explicit imports below take precedence over
# any same-named objects it brings into scope.
from triton_joseph_proj import *

import time

import array_api_compat.torch as xp
import matplotlib.pyplot as plt
import numpy as np
import torch

# Fix the RNG seed and request deterministic PyTorch / cuDNN kernels so that
# repeated benchmark runs are reproducible.
SEED = 0
torch.manual_seed(SEED)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True

# The .npy file holds a pickled Python object, hence allow_pickle=True;
# .item() unwraps the 0-d object array that np.load returns for it.
# Unpickling can execute arbitrary code, so only point this at trusted files.
DATA_PATH = "/home/user/triton/slices_8_lors_40.npy"
dict_data = np.load(DATA_PATH, allow_pickle=True).item()

# Number of timed repetitions for each benchmark below.
N_ITER = 10

# Precision label -> torch dtype used for every array in that benchmark run.
PRECISIONS = {
    "float16": xp.float16,
    "float32": xp.float32,
    "float64": xp.float64,
}


def _synced_clock(device):
    """Wait for all queued CUDA work on ``device``, then read a timer.

    CUDA kernels launch asynchronously, so host-side timestamps are only
    meaningful for GPU work after a synchronize. ``time.perf_counter`` is
    monotonic and high-resolution, which ``time.time`` is not.
    """
    torch.cuda.synchronize(device)
    return time.perf_counter()


for precision, prec in PRECISIONS.items():
    # Move geometry, image and data arrays to the GPU in the current precision.
    xstart = xp.asarray(dict_data['xstart']).cuda().to(prec)
    xend = xp.asarray(dict_data['xend']).cuda().to(prec)
    img_origin = xp.asarray(dict_data['img_origin']).cuda().to(prec)
    voxel_size = xp.asarray(dict_data['voxel_size']).cuda().to(prec)
    img_shape = dict_data['img_shape']
    x_gt = xp.asarray(dict_data['x']).cuda().to(prec)
    # NOTE(review): the same constant 100 is added to the forward projection
    # in the MLEM update below -- presumably a flat additive background; confirm.
    meas = xp.asarray(dict_data['meas']).cuda().to(prec) + 100
    print("xstart", xstart.dtype, "xend", xend.dtype, "img_origin", img_origin.dtype, "voxel_size", voxel_size.dtype, "img_shape", img_shape, "x_gt", x_gt.dtype, "meas", meas.dtype)

    # Endpoints flattened to rows of 3 coordinates once, so the reshape is not
    # repeated inside every timed projector call.
    xstart_flat = xstart.reshape(-1, 3)
    xend_flat = xend.reshape(-1, 3)

    def fwd(x):
        """Forward-project image ``x`` via ``joseph3d_fwd_vec``, shaped like ``meas``."""
        return joseph3d_fwd_vec(
            xstart_flat, xend_flat, x, img_origin, voxel_size, img_shape
        ).reshape(meas.shape)

    def back(y):
        """Back-project data ``y`` (flattened first) via ``joseph3d_back_vec``."""
        return joseph3d_back_vec(
            xstart_flat, xend_flat, y.ravel(), img_origin, voxel_size, img_shape
        )

    # Untimed warm-up call, so one-off first-call costs (presumably Triton JIT
    # compilation) are excluded from the measurement.
    tmp = fwd(x_gt)
    start = _synced_clock(x_gt.device)
    for _ in range(N_ITER):
        tmp = fwd(x_gt)
    forward_timing = (_synced_clock(x_gt.device) - start) / N_ITER

    # The sensitivity image (back-projection of ones) doubles as the untimed
    # warm-up for the back-projector.
    sens_img = back(xp.ones_like(meas))
    x = xp.ones_like(sens_img)

    start = _synced_clock(x_gt.device)
    for _ in range(N_ITER):
        tmp = back(meas)
    backward_timing = (_synced_clock(x_gt.device) - start) / N_ITER

    # N_ITER MLEM updates from a uniform image:
    #   x <- x / sens_img * back(meas / (fwd(x) + 100))
    # NOTE(review): voxels where sens_img == 0 yield inf/nan here, which then
    # spread through later iterations silently; consider masking them.
    start = _synced_clock(x_gt.device)
    for _ in range(N_ITER):
        em_precond = x / sens_img.to(prec)
        ratio = meas / (fwd(x.to(prec)) + 100.)
        bp_ratio = back(ratio.to(prec))
        x = em_precond.to(prec) * bp_ratio.to(prec)
    mlem_timing = (_synced_clock(x_gt.device) - start) / N_ITER

    # Adjointness check: <fwd(u), v> should equal <u, back(v)>. The products
    # are summed in float64 so the comparison is not limited by the reduction
    # precision.
    _x = xp.ones_like(x_gt)
    y = xp.ones_like(meas)
    print("_x", _x.dtype, "y", y.dtype)

    inner_1 = xp.sum(fwd(_x).ravel().double() * y.ravel().double())
    inner_2 = xp.sum(back(y).ravel().double() * _x.ravel().double())
    mean_inner = (inner_1 + inner_2) / 2
    # Relative mismatch, expressed in percent.
    print(f"The difference is {xp.abs(inner_1 - inner_2) / mean_inner * 100:.3e} %")
    print("forward_timing", forward_timing, "backward_timing", backward_timing, "mlem_timing", mlem_timing)
    print("forward and backward timing", forward_timing + backward_timing)
