In [1]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import sys
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
GIT_ROOT = Path("../..").resolve()
SRC = GIT_ROOT / "src"
print(SRC)
if not SRC in sys.path:
    sys.path.append(str(SRC))

  from .autonotebook import tqdm as notebook_tqdm


/home/ubuntu/KEX---CT-reconstruction/src


In [2]:
from utils.data import get_htc2022_train_phantoms, get_htc_trainval_phantoms
from geometries import HTC2022_GEOMETRY

PHANTOMS, VALIDATION_PHANTOMS = get_htc_trainval_phantoms()
geometry = HTC2022_GEOMETRY
ar = 0.25
SINOS = geometry.project_forward(PHANTOMS)

print(PHANTOMS.shape)
print(SINOS.shape)

torch.Size([584, 512, 512])
torch.Size([584, 720, 560])


In [4]:
from geometries.extrapolation import extrapolate_fixpoint, extrapolate_cgm
from utils.polynomials import Legendre, Chebyshev
from utils.tools import MSE

M, K = 50, 50
la_sinos, knwon_angles = geometry.zero_cropp_sinos(SINOS, ar, 0)
exp = extrapolate_cgm(la_sinos, ar, geometry, M, K, PolynomialFamily=Legendre, tol=1.0)
la_sinos[:, ~knwon_angles] = exp[:, ~knwon_angles]

print(la_sinos.shape)
print("Extrapolation MSE:", MSE(la_sinos, SINOS))

k: 1 res mse: tensor(32.8191, device='cuda:0')
k: 2 res mse: tensor(5.0299, device='cuda:0')
k: 3 res mse: tensor(5.5388, device='cuda:0')
k: 4 res mse: tensor(3.5794, device='cuda:0')
k: 5 res mse: tensor(2.7306, device='cuda:0')
k: 6 res mse: tensor(2.1879, device='cuda:0')
k: 7 res mse: tensor(2.2033, device='cuda:0')
k: 8 res mse: tensor(1.9402, device='cuda:0')
k: 9 res mse: tensor(1.8013, device='cuda:0')
k: 10 res mse: tensor(1.5963, device='cuda:0')
k: 11 res mse: tensor(1.3591, device='cuda:0')
k: 12 res mse: tensor(1.1652, device='cuda:0')
k: 13 res mse: tensor(1.1747, device='cuda:0')
k: 14 res mse: tensor(1.2120, device='cuda:0')
k: 15 res mse: tensor(1.2495, device='cuda:0')
k: 16 res mse: tensor(1.2651, device='cuda:0')
k: 17 res mse: tensor(1.2533, device='cuda:0')
k: 18 res mse: tensor(1.2351, device='cuda:0')
k: 19 res mse: tensor(1.2650, device='cuda:0')
k: 20 res mse: tensor(1.2519, device='cuda:0')
k: 21 res mse: tensor(1.2074, device='cuda:0')
k: 22 res mse: tensor

In [6]:
torch.save(la_sinos, GIT_ROOT / "data/exp_cgm_training_sinos.pt")
# la_sinos = torch.load(GIT_ROOT / "data/exp_training_sinos.pt")
# print(la_sinos.shape)

In [7]:
import gc
torch.cuda.empty_cache()
gc.collect()

1745

In [8]:
import torch
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    3283 MB |   11458 MB |  446553 MB |  443270 MB |
|       from large pool |    3283 MB |   11457 MB |  444773 MB |  441489 MB |
|       from small pool |       0 MB |       1 MB |    1780 MB |    1780 MB |
|---------------------------------------------------------------------------|
| Active memory         |    3283 MB |   11458 MB |  446553 MB |  443270 MB |
|       from large pool |    3283 MB |   11457 MB |  444773 MB |  441489 MB |
|       from small pool |       0 MB |       1 MB |    1780 MB |    1780 MB |
|---------------------------------------------------------------

In [12]:
from models.FNOBPs.fnobp import FNO_BP
from statistics import mean
from utils.tools import MSE
from geometries import HTC2022_GEOMETRY
import gc

from models.fbps import AdaptiveFBP
geometry = HTC2022_GEOMETRY
ar = 0.25 #angle ratio of full 360 deg scan

# model = FNO_BP(geometry, [40, 40])
model = AdaptiveFBP(geometry)

dataset = TensorDataset(la_sinos, PHANTOMS)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)
n_epochs = 300
for epoch in range(n_epochs):
    batch_losses = []
    for la_sino_batch, phantom_batch in dataloader:
        optimizer.zero_grad()

        recons = model.forward(la_sinos, knwon_angles)
        
        loss = MSE(recons, phantom_batch)
        loss.backward()

        optimizer.step()
        batch_losses.append(loss.cpu().item())
    
    print("Epoch:", epoch+1, "loss is:", mean(batch_losses))

AttributeError: 'NoneType' object has no attribute 'shape'

In [10]:
from models.fbps import FBP
from models.modelbase import evaluate_batches
fbp = FBP(model.geometry)
#Clear previous plots
for i in plt.get_fignums():
    plt.figure(i)
    plt.close()

VALIDATION_SINOS = geometry.project_forward(VALIDATION_PHANTOMS)
validation_la, known_angles_validation = geometry.zero_cropp_sinos(VALIDATION_SINOS, ar, 0)
exp_validation = extrapolate_fixpoint(validation_la, known_angles_validation, geometry, M, K, PolynomialFamily=Legendre)
validation_la[:, ~known_angles_validation] = exp_validation[:, ~known_angles_validation]

disp_ind = 2
recons = model(validation_la)
evaluate_batches(recons, VALIDATION_PHANTOMS, disp_ind, title="FNO recons")

Series_BP
sinogram mse: tensor(2913.8262, device='cuda:0')
filterd sinogram mse:  tensor(2473.3689, device='cuda:0')
reconstruction mse:  tensor(0.9016, device='cuda:0', dtype=torch.float64)
FBP
sinogram mse: tensor(1245.4938, device='cuda:0')
filterd sinogram mse:  tensor(2484.3323, device='cuda:0')
reconstruction mse:  tensor(0.1939, device='cuda:0', dtype=torch.float64)


In [9]:
#Save model
modelname = "fno_bp"
from models.modelbase import save_model_checkpoint
save_path = GIT_ROOT / "data" / "models" / (modelname + ".pt")
save_model_checkpoint(model, optimizer, loss, ar, save_path)
print("model saved to", save_path)

model saved to /home/emastr/deep-limited-angle/KEX---CT-reconstruction/data/models/fno_bp_fanbeamkits_ar0.25.pt
