A script to compare the performances of NNFBP and MSDNET to reconstruct random spheres with 9 projections.

In [1]:
### Imports ###
from nntomo.utilities import get_MSE_loss
from nntomo.network import nnfbp_training, msdnet_training
from nntomo.nnfbp import DatasetNNFBP
from nntomo.msdnet import DatasetMSDNET
from nntomo.projection_stack import ProjectionStack
from nntomo.volume import Volume

%load_ext autoreload
%autoreload 2
%matplotlib widget

In [2]:
# Generation of the training and validation datasets

training_spheres = Volume.random_spheres(40, shape=512)
training_projections = ProjectionStack.from_volume(training_spheres, 9, 'full')
training_dataset_nnfbp = DatasetNNFBP(training_projections, training_spheres)
training_dataset_msdnet = DatasetMSDNET(training_projections, training_spheres)

validation_spheres = Volume.random_spheres(40, shape=512)
validation_projections = ProjectionStack.from_volume(validation_spheres, 9, 'full')
validation_dataset_nnfbp = DatasetNNFBP(validation_projections, validation_spheres)
validation_dataset_msdnet = DatasetMSDNET(validation_projections, validation_spheres)

Generation of the spheres: [████████████████████████████████████████████████████████████] 40/40 Est wait 00:0.00

Generation of the spheres: [████████████████████████████████████████████████████████████] 40/40 Est wait 00:0.00



In [3]:
# Training

nnfbp = nnfbp_training(training_dataset_nnfbp, validation_dataset_nnfbp, 8, custom_id="nnfbp_network_comparison")
msdnet = msdnet_training(training_dataset_msdnet, validation_dataset_msdnet, depth=100, batch_size=2, max_epoch=50, learning_rate = 1e-3, custom_id="msdnet_network_comparison")

Epoch 190 (n=25) | Best avg MSELoss(): 0.002011 | End of training                                 
Epoch 36 (n=25) | Best avg MSELoss(): 0.012045 | End of training                                 


In [6]:
# Reconstructions

test_spheres = Volume.random_spheres(40, shape=512)
test_projections = ProjectionStack.from_volume(training_spheres, 9, 'full')

nnfbp_rec = test_projections.get_NNFBP_reconstruction(nnfbp)
msdnet_rec = test_projections.get_MSDNET_reconstruction(msdnet)
sirt_rec = test_projections.get_SIRT_reconstruction(force_positive_values=False)

Generation of the spheres: [████████████████████████████████████████████████████████████] 40/40 Est wait 00:0.00

Reconstruction part 1/2: [████████████████████████████████████████████████████████████] 8/8 Est wait 00:0.0.0

Reconstruction part 2/2: [████████████████████████████████████████████████████████████] 8/8 Est wait 00:0.0.0

MSDNET forward: [████████████████████████████████████████████████████████████] 512/512 Est wait 00:0.00



In [7]:
# MSEs

print("MSE NNFBP =", get_MSE_loss(test_spheres, nnfbp_rec))
print("MSE MSDNET =", get_MSE_loss(test_spheres, msdnet_rec))
print("MSE SIRT =", get_MSE_loss(test_spheres, sirt_rec))

MSE NNFBP = 0.011225931918368313
MSE MSDNET = 0.008260020356706584
MSE SIRT = 0.010212703419001875
