In [1]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import sys
from pathlib import Path
import matplotlib
matplotlib.use("WebAgg")
import matplotlib.pyplot as plt
GIT_ROOT = Path("../..").resolve()
SRC = GIT_ROOT / "src"
if not SRC in sys.path:
    sys.path.append(str(SRC))

In [3]:
from utils.data import get_htc2022_train_phantoms, get_kits_train_phantoms, get_htclike_train_phantoms
from utils.polynomials import Legendre, Chebyshev
from geometries import FlatFanBeamGeometry, DEVICE, HTC2022_GEOMETRY, ParallelGeometry
from geometries.geometry_base import naive_sino_filling, mark_cyclic
from models.fbps import AdaptiveFBP as AFBP
from models.FNOBPs.fnobp import FNO_BP
from models.SerieBPs.series_bp1 import Series_BP
from models.modelbase import plot_model_progress
from statistics import mean

ar = 0.25 #angle ratio of full 360 deg scan
PHANTOM_DATA = torch.concat([get_htc2022_train_phantoms(), get_htclike_train_phantoms()])
geometry = HTC2022_GEOMETRY
# PHANTOM_DATA = get_kits_train_phantoms()
# geometry = FlatFanBeamGeometry(720, 560, 410.66, 543.74, 112, [-40,40, -40, 40], [256, 256])
# geometry = FlatFanBeamGeometry(1800, 300, 1.5, 3.0, 4.0, [-1,1,-1,1], [256, 256])
# geometry = ParallelGeometry(1800, 300, [-1,1,-1,1], [256, 256])
N_known_angles = int(geometry.n_projections*ar)
N_angles_out = int(geometry.n_projections*0.6) #can be 0.5 if parallel beam
SINO_DATA = geometry.project_forward(PHANTOM_DATA)
print(SINO_DATA.dtype, SINO_DATA.device, SINO_DATA.shape)

# model = AFBP(geometry)
# model = FNO_BP(geometry, hidden_layers=[40,40], n_known_angles=N_known_angles, n_angles_out=N_angles_out)
model = Series_BP(geometry, ar, 120, 60, Legendre.key)


dataset = TensorDataset(SINO_DATA, PHANTOM_DATA)
dataloader = DataLoader(dataset, batch_size=6, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)
mse_fn = lambda diff : torch.mean(diff**2)
n_epochs = 800
for epoch in range(n_epochs):
    batch_losses, batch_sino_losses, batch_recon_losses = [], [], []
    for sino_batch, phantom_batch in dataloader:
        optimizer.zero_grad()

        start_ind = 0 #torch.randint(0, geometry.n_projections, (1,)).item()
        known_beta_bool = torch.zeros(geometry.n_projections, device=DEVICE, dtype=bool)
        out_beta_bool = known_beta_bool.clone()
        mark_cyclic(known_beta_bool, start_ind, (start_ind+N_known_angles)%geometry.n_projections)#known_beta_bool is True at angles where sinogram is meassured and false otherwise
        mark_cyclic(out_beta_bool, start_ind, (start_ind+N_angles_out)%geometry.n_projections)
        la_sinos = sino_batch * 0 #limited angle sinograms
        la_sinos[:, known_beta_bool] = sino_batch[:, known_beta_bool]

        filtered = model.get_extrapolated_filtered_sinos(la_sinos, known_beta_bool, out_beta_bool)
        gt_filtered = geometry.inverse_fourier_transform(geometry.fourier_transform(sino_batch*geometry.jacobian_det)*geometry.ram_lak_filter())
        loss_sino_domain = mse_fn(gt_filtered-filtered)

        recons = F.relu(geometry.project_backward(filtered/2)) #sinogram covers 360deg  - double coverage
        loss_recon_domain = mse_fn(phantom_batch - recons)

        loss = loss_recon_domain + loss_sino_domain*0.1
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.cpu().item())
        batch_sino_losses.append(loss_sino_domain.cpu().item())
        batch_recon_losses.append(loss_recon_domain.cpu().item())
    
    print("Epoch:", epoch+1, "loss is:", mean(batch_losses), "sino loss is:", mean(batch_sino_losses), "recon loss is:", mean(batch_recon_losses))

torch.float32 cuda:0 torch.Size([30, 720, 560])
Epoch: 1 loss is: 0.7661406057924995 sino loss is: 2.268755567073822 recon loss is: 0.5392650495321522
Epoch: 2 loss is: 0.6174284484641824 sino loss is: 0.5585854411125183 recon loss is: 0.5615699027138029
Epoch: 3 loss is: 0.6158038087022187 sino loss is: 0.5417538046836853 recon loss is: 0.5616284273397806
Epoch: 4 loss is: 0.6146564530127546 sino loss is: 0.5265403389930725 recon loss is: 0.5620024191134474
Epoch: 5 loss is: 0.6140012047689866 sino loss is: 0.5207212805747986 recon loss is: 0.5619290761154603
Epoch: 6 loss is: 0.6135230691885885 sino loss is: 0.5191152572631836 recon loss is: 0.5616115422701772
Epoch: 7 loss is: 0.6131209582774032 sino loss is: 0.5171555995941162 recon loss is: 0.5614053972749103
Epoch: 8 loss is: 0.6127577460983262 sino loss is: 0.5121966660022735 recon loss is: 0.5615380793490873
Epoch: 9 loss is: 0.6124168724460384 sino loss is: 0.5082147896289826 recon loss is: 0.5615953922910472
Epoch: 10 loss is

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [4]:
from models.fbps import FBP
fbp = FBP(model.geometry)
#Clear previous plots
for i in plt.get_fignums():
    plt.figure(i)
    plt.close()

zero_cropped_sinos, known_region = geometry.reflect_fill_sinos(*geometry.zero_cropp_sinos(SINO_DATA[:5], ar, 0))
angles = torch.zeros(geometry.n_projections, dtype=bool, device=DEVICE)
known_angles = mark_cyclic(angles.clone(), 0, N_known_angles)
out_angles = mark_cyclic(angles.clone(), 0, N_angles_out)
# zero_cropped_sinos = naive_sino_filling(zero_cropped_sinos, (~known_region).sum(dim=-1) == 0)
disp_ind = 2
plot_model_progress(model, SINO_DATA[:5], known_angles, out_angles, PHANTOM_DATA[:5], disp_ind=disp_ind)
plot_model_progress(fbp, SINO_DATA[:5], known_angles, out_angles, PHANTOM_DATA[:5], disp_ind=disp_ind)

plt.show()

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
#Full angle fbp recons mse
recons = geometry.fbp_reconstruct(SINO_DATA)
print(torch.mean((recons-PHANTOM_DATA)**2).item())

In [9]:
#Save model
modelname = "fno_bp_fanbeamkits_ar0.25"
from models.modelbase import save_model_checkpoint
save_path = GIT_ROOT / "data" / "models" / (modelname + ".pt")
save_model_checkpoint(model, optimizer, loss, ar, save_path)
print("model saved to", save_path)

model saved to /home/emastr/deep-limited-angle/KEX---CT-reconstruction/data/models/fno_bp_fanbeamkits_ar0.25.pt
