In [1]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import sys
from pathlib import Path
import matplotlib
matplotlib.use("WebAgg")
import matplotlib.pyplot as plt
# Make the project's `src/` directory importable from this notebook.
# The repository root is taken to be two levels above the current working
# directory (assumes the notebook is started from its own folder — TODO confirm).
GIT_ROOT = Path("../..").resolve()
SRC = GIT_ROOT / "src"
# sys.path holds plain strings, so the membership test must use str(SRC).
# Testing the Path object itself never matches, which appended a duplicate
# entry every time this cell was re-run.
if str(SRC) not in sys.path:
    sys.path.append(str(SRC))

In [2]:
from utils.data import get_htc2022_train_phantoms, get_kits_train_phantoms, get_htclike_train_phantoms
from utils.polynomials import Legendre, Chebyshev
from geometries import FlatFanBeamGeometry, DEVICE, HTC2022_GEOMETRY, ParallelGeometry
from geometries.geometry_base import naive_sino_filling, mark_cyclic
from models.fbps import AdaptiveFBP as AFBP
from models.FNOBPs.fnobp import FNO_BP
from models.SerieBPs.series_bp1 import Series_BP
from models.modelbase import plot_model_progress
from statistics import mean

# --- Experiment configuration ---------------------------------------------
# Seed torch's global RNG before any stochastic work (mini-batch shuffling,
# any random initialisation inside the model) so that a fresh
# "Restart Kernel -> Run All" reproduces the same training run.
SEED = 0
torch.manual_seed(SEED)

ar = 0.25  # angle ratio of full 360 deg scan that is treated as measured
OUT_ANGLE_RATIO = 0.6  # fraction of the scan the model should output; can be 0.5 if parallel beam

# Training phantoms and the scan geometry used to project them.
PHANTOM_DATA = torch.concat([get_htc2022_train_phantoms(), get_htclike_train_phantoms()])
geometry = HTC2022_GEOMETRY
# Alternative data sets / geometries, kept for quick switching:
# PHANTOM_DATA = get_kits_train_phantoms()
# geometry = FlatFanBeamGeometry(720, 560, 410.66, 543.74, 112, [-40,40, -40, 40], [256, 256])
# geometry = FlatFanBeamGeometry(1800, 300, 1.5, 3.0, 4.0, [-1,1,-1,1], [256, 256])
# geometry = ParallelGeometry(1800, 300, [-1,1,-1,1], [256, 256])

# Number of projection angles that are known / that the model outputs.
N_known_angles = int(geometry.n_projections*ar)
N_angles_out = int(geometry.n_projections*OUT_ANGLE_RATIO)

# Full-angle sinograms obtained by forward-projecting the phantoms.
SINO_DATA = geometry.project_forward(PHANTOM_DATA)
print(SINO_DATA.dtype, SINO_DATA.device, SINO_DATA.shape)

# Model under training; the commented lines are alternative architectures.
# model = AFBP(geometry)
# model = FNO_BP(geometry, hidden_layers=[40,40], n_known_angles=N_known_angles, n_angles_out=N_angles_out)
model = Series_BP(geometry, ar, 120, 60, Legendre.key)


# Pair each sinogram with the phantom it was projected from and serve the
# pairs in shuffled mini-batches of six.
dataset = TensorDataset(SINO_DATA, PHANTOM_DATA)
dataloader = DataLoader(dataset=dataset, batch_size=6, shuffle=True)

# Adam with a small weight decay over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-5)


def mse_fn(diff):
    """Return the mean of the element-wise squares of `diff` (a residual tensor)."""
    return torch.mean(diff**2)


n_epochs = 100
# Relative weight of the sinogram-domain term in the total training loss.
SINO_LOSS_WEIGHT = 0.1

# Training loop: for every mini-batch, blank out the unmeasured angles, let the
# model produce an extrapolated filtered sinogram, and penalise both the
# filtered-sinogram error and the reconstruction error.
for epoch in range(n_epochs):
    batch_losses, batch_sino_losses, batch_recon_losses = [], [], []
    for sino_batch, phantom_batch in dataloader:
        optimizer.zero_grad()

        # Fixed start angle; the commented expression would randomise it per batch.
        start_ind = 0 #torch.randint(0, geometry.n_projections, (1,)).item()
        known_beta_bool = torch.zeros(geometry.n_projections, device=DEVICE, dtype=bool)
        out_beta_bool = known_beta_bool.clone()
        # known_beta_bool is True at angles where the sinogram is measured and False otherwise
        mark_cyclic(known_beta_bool, start_ind, (start_ind+N_known_angles)%geometry.n_projections)
        # out_beta_bool marks the angles passed to the model as its output range
        mark_cyclic(out_beta_bool, start_ind, (start_ind+N_angles_out)%geometry.n_projections)

        # Limited-angle sinograms: zero everywhere except at the known angles.
        # zeros_like (rather than `sino_batch * 0`) avoids a needless multiply and
        # keeps NaN/Inf entries of the input out of the masked-away region.
        la_sinos = torch.zeros_like(sino_batch)
        la_sinos[:, known_beta_bool] = sino_batch[:, known_beta_bool]

        filtered = model.get_extrapolated_filtered_sinos(la_sinos, known_beta_bool, out_beta_bool)
        # Target: full-angle sinogram scaled by jacobian_det, multiplied by the
        # Ram-Lak filter in the Fourier domain, then transformed back.
        gt_filtered = geometry.inverse_fourier_transform(geometry.fourier_transform(sino_batch*geometry.jacobian_det)*geometry.ram_lak_filter())
        loss_sino_domain = mse_fn(gt_filtered-filtered)

        recons = F.relu(geometry.project_backward(filtered/2)) #sinogram covers 360deg  - double coverage
        loss_recon_domain = mse_fn(phantom_batch - recons)

        loss = loss_recon_domain + loss_sino_domain*SINO_LOSS_WEIGHT
        loss.backward()
        optimizer.step()

        # .item() already copies the scalar to the host; no explicit .cpu() needed.
        batch_losses.append(loss.item())
        batch_sino_losses.append(loss_sino_domain.item())
        batch_recon_losses.append(loss_recon_domain.item())
    
    print("Epoch:", epoch+1, "loss is:", mean(batch_losses), "sino loss is:", mean(batch_sino_losses), "recon loss is:", mean(batch_recon_losses))

torch.float32 cuda:0 torch.Size([30, 720, 560])
Epoch: 1 loss is: 0.4657011207818689 sino loss is: 2.1416247606277468 recon loss is: 0.2515386429309549
Epoch: 2 loss is: 0.1593964686810355 sino loss is: 0.6558303952217102 recon loss is: 0.09381342662566708
Epoch: 3 loss is: 0.13253633472443388 sino loss is: 0.5159070253372192 recon loss is: 0.08094563055158423
Epoch: 4 loss is: 0.1293597278490482 sino loss is: 0.44957255125045775 recon loss is: 0.08440247302202564
Epoch: 5 loss is: 0.14283618812799875 sino loss is: 0.5406239271163941 recon loss is: 0.08877379407525483
Epoch: 6 loss is: 0.1274639356436881 sino loss is: 0.4717994570732117 recon loss is: 0.08028398859526245
Epoch: 7 loss is: 0.11829293390982669 sino loss is: 0.43819177746772764 recon loss is: 0.07447375452392618
Epoch: 8 loss is: 0.11057496912561542 sino loss is: 0.4163780093193054 recon loss is: 0.06893716700159198
Epoch: 9 loss is: 0.10588835503070378 sino loss is: 0.4045316517353058 recon loss is: 0.06543518963365579
E

In [3]:
from models.fbps import FBP

# Baseline reconstructor built on the same geometry object as the trained model.
fbp = FBP(model.geometry)

# Clear previous plots
plt.close("all")

# Evaluate on the first five samples only.
eval_sinos = SINO_DATA[:5]
eval_phantoms = PHANTOM_DATA[:5]

# Zero-cropped, reflect-filled sinograms. Only referenced by the commented-out
# naive filling line below; kept so the names stay defined.
cropped = geometry.zero_cropp_sinos(eval_sinos, ar, 0)
zero_cropped_sinos, known_region = geometry.reflect_fill_sinos(*cropped)

# Boolean masks over projection angles. mark_cyclic is used here for its return
# value (presumably the mask with the range starting at index 0 set — confirm).
angles = torch.zeros(geometry.n_projections, dtype=bool, device=DEVICE)
known_angles = mark_cyclic(angles.clone(), 0, N_known_angles)
out_angles = mark_cyclic(angles.clone(), 0, N_angles_out)
# zero_cropped_sinos = naive_sino_filling(zero_cropped_sinos, (~known_region).sum(dim=-1) == 0)

# Plot the trained model first, then the FBP baseline, for the same sample.
disp_ind = 2
for reconstructor in (model, fbp):
    plot_model_progress(reconstructor, eval_sinos, known_angles, out_angles, eval_phantoms, disp_ind=disp_ind)

# NOTE(review): with the WebAgg backend selected at the top of the notebook, the
# recorded output of this call ends in "RuntimeError: This event loop is already
# running" — confirm whether that is acceptable or a notebook backend should be used.
plt.show()

FNO_BP
sinogram mse: tensor(1652.9601, device='cuda:0')
filterd sinogram mse:  tensor(2180.3899, device='cuda:0')
reconstruction mse:  tensor(0.0402, device='cuda:0', dtype=torch.float64)
FBP
sinogram mse: tensor(1099.7948, device='cuda:0')
filterd sinogram mse:  tensor(2194.1360, device='cuda:0')
reconstruction mse:  tensor(0.1856, device='cuda:0', dtype=torch.float64)
Press Ctrl+C to stop WebAgg server


RuntimeError: This event loop is already running

In [None]:
# Reference number: mean squared error between the phantoms and the
# reconstructions that geometry.fbp_reconstruct produces from the full-angle sinograms.
recons = geometry.fbp_reconstruct(SINO_DATA)
full_angle_mse = torch.mean((recons - PHANTOM_DATA) ** 2)
print(full_angle_mse.item())

In [9]:
# Save model
from models.modelbase import save_model_checkpoint

# Derive the angle-ratio suffix from `ar` instead of hard-coding "0.25", so the
# file name cannot drift from the configuration that was actually trained.
# NOTE(review): the name says "fno_bp" / "kits", while the visible setup cell
# builds a Series_BP model on HTC2022 data — confirm the intended name.
modelname = f"fno_bp_fanbeamkits_ar{ar}"
save_path = GIT_ROOT / "data" / "models" / (modelname + ".pt")
# Create the target directory if needed so saving does not fail on a fresh checkout.
save_path.parent.mkdir(parents=True, exist_ok=True)
# Stores the model together with the optimizer, the last batch loss and the angle ratio.
save_model_checkpoint(model, optimizer, loss, ar, save_path)
print("model saved to", save_path)

model saved to /home/emastr/deep-limited-angle/KEX---CT-reconstruction/data/models/fno_bp_fanbeamkits_ar0.25.pt
