In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import sys
from pathlib import Path
import matplotlib
matplotlib.use("WebAgg")
import matplotlib.pyplot as plt
# Resolve the repository root (two levels above the current working directory)
# and make its `src` directory importable from this notebook.
GIT_ROOT = Path("../..").resolve()
SRC = GIT_ROOT / "src"
# sys.path holds strings, so the membership test must use str(SRC); testing the
# Path object itself never matches and appends a duplicate entry on every re-run.
if str(SRC) not in sys.path:
    sys.path.append(str(SRC))

In [9]:
from utils.data import get_htc2022_train_phantoms
from sklearn.model_selection import train_test_split

# Load the HTC2022 training phantoms and hold three of them out for validation.
HTC_PHANTOMS = get_htc2022_train_phantoms()
print(HTC_PHANTOMS.shape)
# Seed the split so that re-running this cell reproduces the same held-out phantoms.
# An unseeded split reshuffles them on every run, which can place phantoms the model
# has already been trained on into the validation set.
HTC_TRAIN_PHANTOMS, VALIDATION_PHANTOMS = train_test_split(
    HTC_PHANTOMS, test_size=3, random_state=0
)

print(HTC_TRAIN_PHANTOMS.shape)
print(VALIDATION_PHANTOMS.shape)


torch.Size([5, 512, 512])
torch.Size([2, 512, 512])
torch.Size([3, 512, 512])


In [5]:
from utils.data import get_htc2022_train_phantoms, get_kits_train_phantoms, get_synthetic_htc_phantoms
from utils.polynomials import Legendre, Chebyshev
from geometries import FlatFanBeamGeometry, DEVICE, HTC2022_GEOMETRY, ParallelGeometry
from geometries.geometry_base import naive_sino_filling, mark_cyclic
from models.fbps import AdaptiveFBP as AFBP
from models.FNOBPs.fnobp import FNO_BP
from models.SerieBPs.series_bp1 import Series_BP
from models.modelbase import plot_model_progress
from statistics import mean

# --- Experiment configuration ---
ar = 0.25  # angle ratio: fraction of the full 360 deg scan that is measured
SINO_LOSS_WEIGHT = 0.1  # weight of the sinogram-domain term in the total loss

PHANTOM_DATA = torch.concat([HTC_TRAIN_PHANTOMS, get_synthetic_htc_phantoms()])
geometry = HTC2022_GEOMETRY
# Alternative data sets / geometries, kept for reference:
# PHANTOM_DATA = get_kits_train_phantoms()
# geometry = FlatFanBeamGeometry(720, 560, 410.66, 543.74, 112, [-40,40, -40, 40], [256, 256])
# geometry = FlatFanBeamGeometry(1800, 300, 1.5, 3.0, 4.0, [-1,1,-1,1], [256, 256])
# geometry = ParallelGeometry(1800, 300, [-1,1,-1,1], [256, 256])
N_known_angles = int(geometry.n_projections*ar)
N_angles_out = int(geometry.n_projections*0.6)  # can be 0.5 if parallel beam
SINO_DATA = geometry.project_forward(PHANTOM_DATA)
print(SINO_DATA.dtype, SINO_DATA.device, SINO_DATA.shape)

model = Series_BP(geometry, ar, 120, 60, Legendre.key)

dataset = TensorDataset(SINO_DATA, PHANTOM_DATA)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)


def mse_fn(diff):
    """Return the mean of the squared entries of ``diff``."""
    return torch.mean(diff**2)


# --- Training ---
n_epochs = 300
for epoch in range(n_epochs):
    batch_losses, batch_sino_losses, batch_recon_losses = [], [], []
    for sino_batch, phantom_batch in dataloader:
        optimizer.zero_grad()

        # Boolean masks over the projection angles for this batch.
        start_ind = 0 #torch.randint(0, geometry.n_projections, (1,)).item()
        known_beta_bool = torch.zeros(geometry.n_projections, device=DEVICE, dtype=bool)
        out_beta_bool = known_beta_bool.clone()
        # known_beta_bool is True at angles where the sinogram is measured and False otherwise
        mark_cyclic(known_beta_bool, start_ind, (start_ind+N_known_angles)%geometry.n_projections)
        # NOTE(review): out_beta_bool is marked here but not read by anything below.
        mark_cyclic(out_beta_bool, start_ind, (start_ind+N_angles_out)%geometry.n_projections)

        # Limited-angle sinograms: zero everywhere except at the known angles.
        # zeros_like (instead of sino_batch * 0) keeps the unknown region exactly zero
        # even if the input contains NaN/Inf, and avoids a pointless multiply.
        la_sinos = torch.zeros_like(sino_batch)
        la_sinos[:, known_beta_bool] = sino_batch[:, known_beta_bool]

        # Sinogram-domain loss: compare the model's extrapolated, filtered sinograms with
        # the full sinograms scaled by jacobian_det and filtered with ram_lak_filter in
        # the Fourier domain.
        filtered = model.get_extrapolated_filtered_sinos(la_sinos, known_beta_bool)
        gt_filtered = geometry.inverse_fourier_transform(
            geometry.fourier_transform(sino_batch*geometry.jacobian_det)*geometry.ram_lak_filter()
        )
        loss_sino_domain = mse_fn(gt_filtered-filtered)

        # Reconstruction-domain loss. The sinogram covers 360 deg (double coverage),
        # hence the division by 2 before back-projection.
        recons = F.relu(geometry.project_backward(filtered/2))
        loss_recon_domain = mse_fn(phantom_batch - recons)

        loss = loss_recon_domain + loss_sino_domain*SINO_LOSS_WEIGHT
        loss.backward()
        optimizer.step()

        # .item() already returns a Python number; no explicit .cpu() is needed.
        batch_losses.append(loss.item())
        batch_sino_losses.append(loss_sino_domain.item())
        batch_recon_losses.append(loss_recon_domain.item())

    print("Epoch:", epoch+1, "loss is:", mean(batch_losses), "sino loss is:", mean(batch_sino_losses), "recon loss is:", mean(batch_recon_losses))

torch.float32 cuda:0 torch.Size([86, 720, 560])


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch: 1 loss is: 0.5024940176691732 sino loss is: 0.6405034254897725 recon loss is: 0.4384436744428704
Epoch: 2 loss is: 0.48824728201865575 sino loss is: 0.455335332588716 recon loss is: 0.44271374821792375
Epoch: 3 loss is: 0.48581546415689886 sino loss is: 0.4451517273079265 recon loss is: 0.44130029027465284
Epoch: 4 loss is: 0.4852509280078704 sino loss is: 0.4382050335407257 recon loss is: 0.4414304239087397
Epoch: 5 loss is: 0.48470636900000036 sino loss is: 0.43378400802612305 recon loss is: 0.44132796779099276
Epoch: 6 loss is: 0.48504277311333227 sino loss is: 0.428209207274697 recon loss is: 0.44222185211493237
Epoch: 7 loss is: 0.483845809793083 sino loss is: 0.42493477734652435 recon loss is: 0.44135233165203525
Epoch: 8 loss is: 0.4831060509856447 sino loss is: 0.42217493057250977 recon loss is: 0.44088855758973095
Epoch: 9 loss is: 0.4823512588544299 sino loss is: 0.4192949641834606 recon loss is: 0.4404217615555607
Epoch: 10 loss is: 0.48216581865824876 sino loss is: 0

KeyboardInterrupt: 

In [10]:
from models.fbps import FBP
# Baseline reconstructor built on the same geometry as the trained model.
fbp = FBP(model.geometry)

# Close every figure left open by earlier runs before drawing new ones.
plt.close("all")

# Forward-project the held-out phantoms, then crop and reflect-fill the sinograms.
VALIDATION_SINOS = geometry.project_forward(VALIDATION_PHANTOMS)
cropped_sinos_and_region = geometry.zero_cropp_sinos(VALIDATION_SINOS, ar, 0)
zero_cropped_validation_sinos, known_region = geometry.reflect_fill_sinos(*cropped_sinos_and_region)

# Boolean angle masks, starting from an all-False template.
angles = torch.zeros(geometry.n_projections, dtype=bool, device=DEVICE)
known_angles = mark_cyclic(angles.clone(), 0, N_known_angles)
out_angles = mark_cyclic(angles.clone(), 0, N_angles_out)
# zero_cropped_sinos = naive_sino_filling(zero_cropped_sinos, (~known_region).sum(dim=-1) == 0)

# Report and plot the trained model first, then the FBP baseline, on the same sample.
disp_ind = 0
for reconstructor in (model, fbp):
    plot_model_progress(
        reconstructor,
        VALIDATION_SINOS,
        known_angles,
        out_angles,
        VALIDATION_PHANTOMS,
        disp_ind=disp_ind,
    )

plt.show()

Series_BP
sinogram mse: tensor(2913.8262, device='cuda:0')
filterd sinogram mse:  tensor(2473.3689, device='cuda:0')
reconstruction mse:  tensor(0.9016, device='cuda:0', dtype=torch.float64)
FBP
sinogram mse: tensor(1245.4938, device='cuda:0')
filterd sinogram mse:  tensor(2484.3323, device='cuda:0')
reconstruction mse:  tensor(0.1939, device='cuda:0', dtype=torch.float64)


In [None]:
# Reference error: MSE between FBP reconstructions of the full (unmasked) sinograms
# and the phantoms they were projected from.
recons = geometry.fbp_reconstruct(SINO_DATA)
full_angle_error = recons - PHANTOM_DATA
print(full_angle_error.pow(2).mean().item())

In [9]:
# Save model
from models.modelbase import save_model_checkpoint

# Build the checkpoint name from `ar` instead of repeating the ratio by hand, so the
# file name cannot drift from the value the model was trained with.
# NOTE(review): the "fno_bp_fanbeamkits" prefix is kept for backward compatibility, but
# the training cell in this notebook builds a Series_BP on HTC2022_GEOMETRY; confirm
# the prefix before overwriting an existing checkpoint.
modelname = f"fno_bp_fanbeamkits_ar{ar}"
save_path = GIT_ROOT / "data" / "models" / (modelname + ".pt")
# Make sure the target directory exists before handing the path to the saver.
save_path.parent.mkdir(parents=True, exist_ok=True)
save_model_checkpoint(model, optimizer, loss, ar, save_path)
print("model saved to", save_path)

model saved to /home/emastr/deep-limited-angle/KEX---CT-reconstruction/data/models/fno_bp_fanbeamkits_ar0.25.pt
