In [1]:
import torch

from diffdrr import get_device, load_example_ct, delete_tensor
from diffdrr.utils.camera import Detector
from diffdrr.projectors.siddon import Siddon
# Resolve the compute device through diffdrr's helper
# (presumably falls back to CPU when CUDA is unavailable — confirm in diffdrr).
requested_device = "cuda"
device = get_device(requested_device)

In [2]:
# Detector geometry used by the cells below (units not visible here — presumably pixels / mm; confirm).
height, width = 50, 50  # detector resolution (rows, columns)
delx, dely = 5e-2, 5e-2  # pixel spacing along each detector axis
sdr = 200  # passed to Detector.make_xrays; presumably the source-to-detector radius

In [3]:
# Seed the RNG so the sampled poses (and every downstream result) survive Restart & Run All.
# torch.manual_seed seeds the generators on all devices, including CUDA.
SEED = 0
torch.manual_seed(SEED)

batch_size = 2

# Random poses: translations are standard-normal offsets; rotations are wrapped into [0, 2*pi).
translations = torch.randn(batch_size, 3, device=device)
rotations = torch.randn(batch_size, 3, device=device)
rotations = torch.remainder(rotations, 2 * torch.pi)
rotations.shape

torch.Size([2, 3])

In [4]:
# Build the detector on the same device as every other tensor in this notebook.
# Hard-coding "cuda" here would diverge from `device` whenever get_device resolved elsewhere.
detector = Detector(height, width, delx, dely, device)
# source: (B, 1, 1, 3), one point per pose; target: (B, height, width, 3), one point per pixel.
source, target = detector.make_xrays(sdr, rotations, translations)
source.shape, target.shape

(torch.Size([2, 1, 1, 3]), torch.Size([2, 50, 50, 3]))

In [5]:
volume, spacing = load_example_ct()
siddon = Siddon(volume, spacing, device)

# Tensor versions of the CT geometry for the manual Siddon walk-through below.
# NOTE: `spacing` is rebound here, after `Siddon` has already consumed the raw value.
spacing = torch.tensor(spacing, device=device)
# Number of boundary planes per axis: one more than the number of voxels.
dims = torch.tensor(volume.shape, device=device) + 1

In [6]:
# Each ray is parametrized as source + alpha * (target - source).
# For every axis, find alpha at the first (index 0) and last boundary planes of the CT box.
sdd = target - source
planes = torch.zeros(3, device=device)  # first boundary plane on each axis
alpha0 = (planes * spacing - source) / sdd
planes = dims - 1  # last boundary plane on each axis
alpha1 = (planes * spacing - source) / sdd
alphas = torch.stack([alpha0, alpha1])

# Per axis, the nearer crossing has the smaller alpha and the farther one the larger.
near_per_axis = alphas.min(dim=0).values
far_per_axis = alphas.max(dim=0).values
# The ray is inside the box after entering on every axis and before leaving on any axis.
alphamin = near_per_axis.max(dim=-1).values.unsqueeze(1)
alphamax = far_per_axis.min(dim=-1).values.unsqueeze(1)

alphamin.shape, alphamax.shape

(torch.Size([2, 1, 50, 50]), torch.Size([2, 1, 50, 50]))

In [10]:
def plane_alphas(n_planes, delta, src, sdd_axis):
    """Ray parameter alpha at which every ray crosses each plane along one axis.

    Args:
        n_planes: number of planes on this axis (voxels + 1).
        delta: spacing between consecutive planes on this axis.
        src: (B,) source coordinate on this axis.
        sdd_axis: (B, 1, H, W) source-to-target displacement along this axis.

    Returns:
        (B, n_planes, H, W) tensor whose entry [b, i, h, w] solves
        src[b] + alpha * sdd_axis[b, 0, h, w] == i * delta.
    """
    planes = torch.arange(n_planes, dtype=torch.float32, device=src.device)
    offsets = planes * delta - src.unsqueeze(-1)  # (B, n_planes)
    return offsets[:, :, None, None] / sdd_axis


# Get the CT sizing and spacing parameters
dx, dy, dz = spacing
nx, ny, nz = dims

# All pixels of a batch item share one source point (source is (B, 1, 1, 3)).
sx, sy, sz = source[:, 0, 0, 0], source[:, 0, 0, 1], source[:, 0, 0, 2]
sdd = (target - source).unsqueeze(1)  # (B, 1, H, W, 3) source-to-detector displacement vector

# Alpha at every plane intersection, one axis at a time, then concatenated along dim 1.
alphax = plane_alphas(nx, dx, sx, sdd[..., 0])
alphay = plane_alphas(ny, dy, sy, sdd[..., 1])
alphaz = plane_alphas(nz, dz, sz, sdd[..., 2])
alphas = torch.cat([alphax, alphay, alphaz], dim=1)
alphas.shape

torch.Size([2, 1160, 50, 50])

In [8]:
# Keep only plane crossings inside each ray's [alphamin, alphamax] interval.
# Everything else becomes NaN, which torch.sort orders after all finite values.
# masked_fill is out-of-place, so the tensor produced by the previous cell is not mutated.
good_idxs = torch.logical_and(alphas >= alphamin, alphas <= alphamax)
alphas = alphas.masked_fill(~good_idxs, torch.nan)
alphas = torch.sort(alphas, dim=1).values
# Shape is unchanged: (B, nx + ny + nz, H, W). Trimming all-NaN plane slots must be done
# per batch item; a single boolean mask over (B, planes) would merge the batch dimension.
alphas.shape

torch.Size([2, 1160, 50, 50])