In [1]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import sys
from pathlib import Path
import matplotlib
matplotlib.use("WebAgg")
import matplotlib.pyplot as plt
# Make the repository's `src` directory importable. The repo root is taken to be two directories
# above the current working directory; the project packages imported below (`geometries`,
# `models`, `utils`) are presumably found under `src`.
GIT_ROOT = Path("../..").resolve()
SRC = GIT_ROOT / "src"
# sys.path holds strings, so compare against str(SRC): testing the Path object itself never
# matches, which made every re-run of this cell append another duplicate entry.
if str(SRC) not in sys.path:
    sys.path.append(str(SRC))

In [4]:
from geometries import FlatFanBeamGeometry, DEVICE, HTC2022_GEOMETRY as geometry
from geometries.geometry_base import naive_sino_filling
from utils.tools import get_htc2022_train_phantoms, get_kits_train_phantoms
from models.fbps import AdaptiveFBP as AFBP
from models.FNOBPs.fnobp import FNO_BP
from models.modelbase import plot_model_progress
from statistics import mean

ar = 0.25  # angle ratio: fraction of the full 360 deg scan that is measured
SEED = 0
# Seed torch's global RNG so DataLoader shuffling, the random start angles below and any
# torch-based weight initialisation are reproducible across Restart & Run All.
torch.manual_seed(SEED)

# PHANTOM_DATA = get_htc2022_train_phantoms()
PHANTOM_DATA = get_kits_train_phantoms()
# NOTE: this rebinds `geometry`, replacing the HTC2022_GEOMETRY imported under that name above.
geometry = FlatFanBeamGeometry(720, 560, 410.66, 543.74, 112, [-40,40, -40, 40], [256, 256])
SINO_DATA = geometry.project_forward(PHANTOM_DATA)

# model = AFBP(geometry)
model = FNO_BP(geometry, hidden_layers=[40,40], modes=geometry.projection_size//2)

dataset = TensorDataset(SINO_DATA, PHANTOM_DATA)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3, betas=(0.8, 0.95))
n_epochs = 100


def mse_fn(diff):
    """Return the mean of the squared entries of `diff`."""
    return torch.mean(diff**2)


def _limited_angle_input(geometry, sino_batch, start_ind, ar):
    """Simulate a limited-angle scan of `sino_batch` and fill in the unmeasured angles.

    A fraction `ar` of the angles, starting at index `start_ind`, is kept. The result is rotated
    by -start_ind so the measured angles always sit in the same region, which the FNO needs.
    """
    # known_beta_bool is True at angles where the sinogram is measured and False otherwise
    la_sinos, known_beta_bool = geometry.zero_cropp_sinos(sino_batch, ar=ar, start_ind=start_ind)
    la_sinos, known_region = geometry.reflect_fill_sinos(la_sinos, known_beta_bool)
    # Mask of rows with no unknown entry along the last dim (presumably angles that are fully
    # known after reflection — TODO confirm the axis layout)
    la_sinos = naive_sino_filling(la_sinos, (~known_region).sum(dim=-1) == 0)
    return geometry.rotate_sinos(la_sinos, -start_ind)


def _compute_losses(model, geometry, sino_batch, phantom_batch, start_ind, ar):
    """Return (sinogram-domain loss, reconstruction-domain loss) for one training batch."""
    la_sinos = _limited_angle_input(geometry, sino_batch, start_ind, ar)
    filtered = model.get_extrapolated_filtered_sinos(la_sinos)
    filtered = geometry.rotate_sinos(filtered, start_ind)  # rotate back to align with the ground truth

    # Ground-truth filtered sinogram from the full scan. It is only a target, so no autograd
    # graph needs to be recorded for it.
    with torch.no_grad():
        gt_filtered = geometry.inverse_fourier_transform(
            geometry.fourier_transform(sino_batch * geometry.jacobian_det) * geometry.ram_lak_filter()
        )
    loss_sino_domain = mse_fn(gt_filtered - filtered)

    # The sinogram covers 360 deg (double coverage), hence the division by 2.
    recons = F.relu(geometry.project_backward(filtered / 2))
    loss_recon_domain = mse_fn(phantom_batch - recons)
    return loss_sino_domain, loss_recon_domain


for epoch in range(n_epochs):
    batch_losses, batch_sino_losses, batch_recon_losses = [], [], []
    for sino_batch, phantom_batch in dataloader:
        optimizer.zero_grad()

        # One random start angle per batch, shared by every sinogram in the batch
        start_ind = torch.randint(0, geometry.n_projections, (1,)).item()
        loss_sino_domain, loss_recon_domain = _compute_losses(
            model, geometry, sino_batch, phantom_batch, start_ind, ar
        )
        loss = loss_recon_domain + loss_sino_domain
        loss.backward()
        optimizer.step()

        # .item() already copies the scalar to the host; a separate .cpu() is unnecessary
        batch_losses.append(loss.item())
        batch_sino_losses.append(loss_sino_domain.item())
        batch_recon_losses.append(loss_recon_domain.item())

    print("Epoch:", epoch+1, "loss is:", mean(batch_losses), "sino loss is:", mean(batch_sino_losses), "recon loss is:", mean(batch_recon_losses))

Epoch: 1 loss is: 0.5134387958620554 sino loss is: 0.4741862479150295 recon loss is: 0.0392525479470259
Epoch: 2 loss is: 0.23167837030685828 sino loss is: 0.19200844463706015 recon loss is: 0.03966992566979812
Epoch: 3 loss is: 0.22019308199432208 sino loss is: 0.18345234975218772 recon loss is: 0.03674073224213435
Epoch: 4 loss is: 0.20225187926538782 sino loss is: 0.16950221705436708 recon loss is: 0.03274966221102075
Epoch: 5 loss is: 0.19319195611834883 sino loss is: 0.1608447540998459 recon loss is: 0.03234720201850294
Epoch: 6 loss is: 0.1870912845536939 sino loss is: 0.1557417548596859 recon loss is: 0.031349529694007984
Epoch: 7 loss is: 0.18549428325208434 sino loss is: 0.15452100497484206 recon loss is: 0.030973278277242274
Epoch: 8 loss is: 0.18453098201029333 sino loss is: 0.15355392330884934 recon loss is: 0.030977058701443995
Epoch: 9 loss is: 0.18189671223567305 sino loss is: 0.15130121728777884 recon loss is: 0.03059549494789419
Epoch: 10 loss is: 0.17907692619618154 s

KeyboardInterrupt: 

In [11]:
from models.fbps import FBP
fbp = FBP(model.geometry)
# Clear figures left over from previous runs of this cell
plt.close("all")

N_DISPLAY = 5  # number of samples handed to the progress plots
disp_ind = 2   # passed to plot_model_progress as disp_ind (presumably which of those samples is shown)

# Same limited-angle preprocessing as in training, with start_ind=0 (so the rotation applied
# during training would presumably be a no-op here)
zero_cropped_sinos, known_region = geometry.reflect_fill_sinos(*geometry.zero_cropp_sinos(SINO_DATA[:N_DISPLAY], ar, 0))
zero_cropped_sinos = naive_sino_filling(zero_cropped_sinos, (~known_region).sum(dim=-1) == 0)
plot_model_progress(model, zero_cropped_sinos, SINO_DATA[:N_DISPLAY], PHANTOM_DATA[:N_DISPLAY], disp_ind=disp_ind)
plot_model_progress(fbp, zero_cropped_sinos, SINO_DATA[:N_DISPLAY], PHANTOM_DATA[:N_DISPLAY], disp_ind=disp_ind)

# NOTE(review): with the WebAgg backend selected at the top of the notebook, this call raised
# "RuntimeError: This event loop is already running" in the saved run (presumably because it
# tries to start its own server loop inside the kernel's running loop). Consider an inline or
# widget backend when running inside Jupyter.
plt.show()

FNO_BP
sinogram mse: tensor(10.6201, device='cuda:0')
filterd sinogram mse:  tensor(0.1603, device='cuda:0')
reconstruction mse:  tensor(0.0342, device='cuda:0', dtype=torch.float64)
FBP
sinogram mse: tensor(10.6201, device='cuda:0')
filterd sinogram mse:  tensor(0.2048, device='cuda:0')
reconstruction mse:  tensor(0.0302, device='cuda:0', dtype=torch.float64)
Press Ctrl+C to stop WebAgg server


RuntimeError: This event loop is already running

In [7]:
# Baseline: FBP reconstruction from the full-angle (360 deg) sinograms, scored by mean squared
# error against the ground-truth phantoms.
recons = geometry.fbp_reconstruct(SINO_DATA)
squared_error = (recons - PHANTOM_DATA).pow(2)
print(squared_error.mean().item())

0.005315318074442698


In [31]:
from models.modelbase import save_model_checkpoint

# Save the trained model together with the optimizer, `loss` and `ar`. `loss` is whatever the
# training cell left behind: the loss of the last batch it processed (in the saved outputs the
# training was stopped early with a KeyboardInterrupt).
modelname = f"fno_bp_ar{ar}"  # built from `ar` so the file name always matches the angle ratio
save_path = GIT_ROOT / "data" / "models" / (modelname + ".pt")
save_path.parent.mkdir(parents=True, exist_ok=True)  # make sure the target directory exists
save_model_checkpoint(model, optimizer, loss, ar, save_path)
print("model saved to", save_path)

model saved to /home/emastr/deep-limited-angle/KEX---CT-reconstruction/data/models/fno_bp_ar0.25.pt
