In [None]:
# CODE # 8 FOR SUMMER PROJECT: 19/08 REAL DATA PROCESSING

# IMPORT NECESSARY PACKAGES:

import numpy as np
import odlpet # PET imaging module
import odl # Reconstruction module
import torch # Deep learning module
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision # Deep learning for image analysis module
from torchvision import datasets, transforms
import matplotlib.pyplot as plt # Plotting module
torch.manual_seed(123);  # reproducibility
from odlpet.scanner.scanner import Scanner # Obtain PET scanner information
from odlpet.scanner.compression import Compression # Obtain PET scanner compression information
from odl.contrib.torch import OperatorAsModule # Module able to convert ODL operators into PyTorch tensor
import time # Chronometer
from odl.contrib import fom # Module for Figures of Merit (Similarity Indexes)
import nibabel # Work with real data

In [None]:
# Get mini-PET scanner parameters with ODLPET and build the PET forward
# projector (image space -> sinogram space) together with its adjoint.

scanner = Scanner() # Default-constructed odlpet scanner description

scanner.num_rings = 35 # Number of rings

compression = Compression(scanner) # Acquisition / sinogram compression settings for this scanner

# Select acquisition parameters (set before get_projector below, which presumably reads them)
compression.max_num_segments = 0 # only direct sinograms, no transverse
compression.num_of_views = 180 # angle resolution (number of views)
compression.num_non_arccor_bins = 147 # tangential resolution
compression.data_arc_corrected = True
pet_projector = compression.get_projector(restrict_to_cylindrical_FOV=False) # domain resolution: 35 in z axis, 371 in x,y axis (original note; confirm with pet_projector.domain)
pet_projector_adj=pet_projector.adjoint

In [None]:
pet_projector.range # Display the projector's range: the PET sinogram (data) space

In [None]:
pet_projector.domain # Display the projector's domain: the PET image space

In [None]:
# Wrap the forward projector and its adjoint as torch modules so that both can
# be applied directly to torch tensors.
fwd_op_mod = OperatorAsModule(pet_projector)
fwd_op_adj_mod = OperatorAsModule(pet_projector_adj)

In [None]:
# Class for performing MLEM reconstruction
class MLEM(odl.operator.Operator):
    """Reconstruction operator: data (sinogram) -> image via ``niter`` MLEM iterations.

    Parameters
    ----------
    op : odl.Operator
        Forward (projection) operator. This operator maps ``op.range`` (data
        space) to ``op.domain`` (image space).
    niter : int
        Number of MLEM iterations run on every call.
    """

    def __init__(self, op, niter):
        # Spaces are taken from ``op`` itself (not from a global projector) so the
        # class works with any forward operator. MLEM with more than one iteration
        # is not linear in the data, so the operator is not declared linear.
        super().__init__(domain=op.range, range=op.domain, linear=False)
        self.op = op
        self.niter = niter

    def _call(self, data):
        # Start from a constant image of ones; ``mlem`` updates ``reco`` in place.
        reco = self.range.one()
        odl.solvers.iterative.statistical.mlem(self.op, reco, data, niter=self.niter)
        return reco

In [None]:
# MLEM reconstruction operators and their torch-module wrappers:
#   *_comp -> 10 iterations, used as the COMParison baseline
#   *_net  -> 1 iteration, used as the input to the NETworks
mlem_op_comp = MLEM(op=pet_projector, niter=10)
mlem_op_comp_mod = OperatorAsModule(mlem_op_comp)

mlem_op_net = MLEM(op=pet_projector, niter=1)
mlem_op_net_mod = OperatorAsModule(mlem_op_net)

In [None]:
# Real data loading
REAL_DATA_PATH = 'filename.mnc'  # put here the name of the file where measures are done

real = nibabel.load(REAL_DATA_PATH)
# get_data() is deprecated and removed in nibabel 5; get_fdata() returns the
# (scaled) image data as floating point.
real_data = real.get_fdata()  # Get miniPET data
real_data_ = np.transpose(real_data, (0, 2, 1))  # Adequate data dimensions: swap the last two axes

# Real data display
data = pet_projector.range.element(real_data_)  # Wrap the measurements as an element of the sinogram space
data[data.asarray() < 0] = 0  # Set negative values to 0
data.show(coords=(1, None, None))  # Display direct sinograms

In [None]:
# Transform data into a PyTorch tensor.
# np.asarray() yields a plain array both for the ODL element built above and
# for an already-converted CPU tensor, so re-running this cell is harmless.
# torch.tensor() copies, so the tensor does not share memory with the ODL element.
data = torch.tensor(np.asarray(data))

In [None]:
# See data as fly-through images: 10 evenly spaced slices along each axis
# (labelled axial = axis 0, coronal = axis 1, sagittal = axis 2).

size_sino = data.shape
sino_np = np.asarray(data)  # plain array for imshow

for view_name, axis in (('Axial', 0), ('Coronal', 1), ('Sagittal', 2)):
    fig, axes = plt.subplots(nrows=1, ncols=10, figsize=(30, 3))
    slice_indexes = np.linspace(0, size_sino[axis] - 1, 10).astype(int)
    for ax, ind in zip(axes, slice_indexes):
        ax.imshow(np.take(sino_np, ind, axis=axis), cmap='gray')
        ax.set_title(f'index {ind}')
        ax.axis('off')
    fig.suptitle(f'Sinogram - {view_name} slices (axis {axis})')
    plt.show()

In [None]:
# MLEM RECONSTRUCTIONS OF REAL DATA
# ODL's OperatorAsModule requires a leading batch axis on its input, so one is
# added with unsqueeze(0) and removed again with [0]: mlem_1 / mlem_10 keep the
# image-domain shape. No gradients are needed here.

with torch.no_grad():
    # 1 ITERATION MLEM
    mlem_1 = mlem_op_net_mod(data.unsqueeze(0))[0]

    # 10 ITERATIONS MLEM
    mlem_10 = mlem_op_comp_mod(data.unsqueeze(0))[0]

In [None]:
# See reconstructions as fly-through images: 10 evenly spaced slices along each
# axis (axial = axis 0, coronal = axis 1, sagittal = axis 2).

size_reco = mlem_1.shape  # also used by later cells

reconstructions = (
    ('1 MLEM iteration reconstruction', mlem_1),
    ('10 MLEM iterations reconstruction', mlem_10),
)
for title, volume in reconstructions:
    volume_np = volume.detach().cpu().numpy()
    for view_name, axis in (('Axial', 0), ('Coronal', 1), ('Sagittal', 2)):
        fig, axes = plt.subplots(nrows=1, ncols=10, figsize=(30, 3))
        slice_indexes = np.linspace(0, volume_np.shape[axis] - 1, 10).astype(int)
        for ax, ind in zip(axes, slice_indexes):
            ax.imshow(np.take(volume_np, ind, axis=axis), cmap='gray')
            ax.axis('off')
        fig.suptitle(f'{title} - {view_name} slices')
        plt.show()

In [None]:
# Define the architecture of the network: 2D U-NET

def double_conv(in_channels, out_channels):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged.

    Parameters
    ----------
    in_channels : int
        Channels of the input feature map.
    out_channels : int
        Channels produced by both convolutions.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),  # 2D convolution
        nn.BatchNorm2d(out_channels),  # Batch normalization
        nn.ReLU(inplace=True),  # Rectified Linear Unit
        nn.Conv2d(out_channels, out_channels, 3, padding=1),  # 2D convolution
        nn.BatchNorm2d(out_channels),  # Batch normalization
        nn.ReLU(inplace=True),  # Rectified Linear Unit
    )


class UNet(nn.Module):
    """2D U-Net-style denoiser mapping an (N, 1, H, W) batch to (N, 1, H, W).

    ``forward`` uses three encoder levels and two decoder levels only.
    ``dconv_down4``, ``dconv_down5``, ``dconv_up4`` and ``dconv_up3`` are built
    but unused; they are kept so the parameter set (and therefore networks or
    state dicts saved from this class) stays unchanged.
    """

    def __init__(self):
        super().__init__()

        self.dconv_down1 = double_conv(1, 64)  # 1 -> 64 channels
        self.dconv_down2 = double_conv(64, 128)  # 64 -> 128 channels
        self.dconv_down3 = double_conv(128, 256)  # 128 -> 256 channels
        self.dconv_down4 = double_conv(256, 512)  # 256 -> 512 channels (unused in forward)
        self.dconv_down5 = double_conv(512, 1024)  # 512 -> 1024 channels (unused in forward)

        self.maxpool = nn.MaxPool2d(2, ceil_mode=True)  # 2x pooling (odd sizes rounded up)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # 2x bilinear upsampling

        self.dconv_up4 = double_conv(512 + 1024, 512)  # 1536 -> 512 channels (unused in forward)
        self.dconv_up3 = double_conv(256 + 512, 256)  # 768 -> 256 channels (unused in forward)
        self.dconv_up2 = double_conv(128 + 256, 128)  # 384 -> 128 channels
        self.dconv_up1 = double_conv(128 + 64, 64)  # 192 -> 64 channels

        # 1x1 convolution: per-pixel linear combination of the 64 channels into 1
        self.conv_last = nn.Conv2d(64, 1, 1)

    @staticmethod
    def _concat_skip(x, skip):
        """Concatenate ``skip`` to ``x`` along channels.

        If the spatial shapes differ, ``x`` is first resized to ``skip``'s
        spatial shape with ``F.interpolate`` (default nearest mode).
        """
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=tuple(skip.shape[2:]))
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        # Encoder: double convolution, then 2x pooling
        conv1 = self.dconv_down1(x)  # 64 ch, full size
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)  # 128 ch, 1/2 size
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)  # 256 ch, 1/4 size
        x = self.maxpool(conv3)  # 256 ch, 1/8 size

        # Decoder. The deeper levels (dconv_up3 + upsample) are disabled: the
        # pooled level-3 features are resized straight to conv2's size, giving
        # the 256 + 128 channels dconv_up2 expects.
        x = self._concat_skip(x, conv2)
        x = self.dconv_up2(x)
        x = self.upsample(x)

        x = self._concat_skip(x, conv1)
        x = self.dconv_up1(x)

        # Output layer: one channel, same spatial size as the input
        return self.conv_last(x)


In [None]:
# Define the architecture of the network: 3D U-NET

def double_conv3d(in_channels, out_channels):
    """Two stacked (3x3x3 Conv3d -> BatchNorm3d -> ReLU) stages.

    3D counterpart of the 2D ``double_conv``. It has its own name so running
    this cell does not overwrite the 2D helper used by ``UNet``.
    ``padding=1`` keeps the spatial size unchanged.
    """
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, 3, padding=1),  # 3D convolution
        nn.BatchNorm3d(out_channels),  # Batch normalization
        nn.ReLU(inplace=True),  # Rectified Linear Unit
        nn.Conv3d(out_channels, out_channels, 3, padding=1),  # 3D convolution
        nn.BatchNorm3d(out_channels),  # Batch normalization
        nn.ReLU(inplace=True),  # Rectified Linear Unit
    )


class UNet3d(nn.Module):
    """3D U-Net denoiser mapping an (N, 1, D, H, W) batch to (N, 1, D, H, W).

    Three pooling levels with a 512-channel bottleneck and skip connections at
    every level.
    """

    def __init__(self):
        super().__init__()

        self.dconv_down1 = double_conv3d(1, 64)  # 1 -> 64 channels
        self.dconv_down2 = double_conv3d(64, 128)  # 64 -> 128 channels
        self.dconv_down3 = double_conv3d(128, 256)  # 128 -> 256 channels
        self.dconv_down4 = double_conv3d(256, 512)  # 256 -> 512 channels (bottleneck)

        self.maxpool = nn.MaxPool3d(2, ceil_mode=True)  # 2x pooling (odd sizes rounded up)
        self.upsample = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)  # 2x trilinear upsampling

        self.dconv_up3 = double_conv3d(256 + 512, 256)  # 768 -> 256 channels
        self.dconv_up2 = double_conv3d(128 + 256, 128)  # 384 -> 128 channels
        self.dconv_up1 = double_conv3d(128 + 64, 64)  # 192 -> 64 channels

        # 1x1x1 convolution: per-voxel linear combination of the 64 channels into 1
        self.conv_last = nn.Conv3d(64, 1, 1)

    @staticmethod
    def _concat_skip(x, skip):
        """Concatenate ``skip`` to ``x`` along channels.

        If the spatial shapes differ (e.g. odd sizes rounded up by the pooling),
        ``x`` is first resized to ``skip``'s spatial shape with ``F.interpolate``
        (default nearest mode).
        """
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=tuple(skip.shape[2:]))
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        # Encoder: double convolution, then 2x pooling
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        # Bottleneck
        x = self.dconv_down4(x)

        # Decoder: upsample, concatenate the matching encoder features, convolve
        x = self._concat_skip(self.upsample(x), conv3)
        x = self.dconv_up3(x)

        x = self._concat_skip(self.upsample(x), conv2)
        x = self.dconv_up2(x)

        x = self._concat_skip(self.upsample(x), conv1)
        x = self.dconv_up1(x)

        # Output layer: one channel, same spatial size as the input
        return self.conv_last(x)


In [None]:
# Load NNs trained with synthetic data.
# NOTE(review): these files appear to hold whole pickled nn.Module objects (they
# are called directly later), so the UNet / UNet3d cells above must be run first.
# torch.load unpickles arbitrary objects -- only load trusted files. On
# PyTorch >= 2.6, loading whole modules additionally needs weights_only=False.
# .eval() puts BatchNorm in inference mode (stored running statistics instead of
# the statistics of each test batch).
low_resolution_2d_ellipsoids_net = torch.load('2d_28x28_denoise_ellipsoids.torch').eval()  # 2D network with low resolution ellipsoids
low_resolution_2d_mnist_net = torch.load('2d_28x28_denoise_mnist.torch').eval()  # 2D network with MNIST images at low (28x28) resolution
pet_resolution_2d_ellipsoids_net = torch.load('2d_PET_Geometry_ellipsoids_denoise_network.torch').eval()  # 2D network with ellipsoids at mini-PET geometry
pet_resolution_2d_mnist_net = torch.load('2d_PET_Geometry_mnist_denoise_network.torch').eval()  # 2D network with MNIST images at mini-PET resolution
denoise_net_3d = torch.load('3d_denoise.torch').eval()  # 3D denoising network with ellipsoids

In [None]:
# Obtain denoised versions of the real images.
# All networks run for inference: eval() makes BatchNorm use its running
# statistics and torch.no_grad() avoids building an autograd graph over the
# whole volume.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Real volume to be denoised, shaped (batch=1, channel=1, *mlem_1.shape)
test = torch.as_tensor(mlem_1, dtype=torch.float32).detach().to(device)[None, None]

# Denoised outputs, same shape as `test`
denoised_low_resolution_2d_ellipsoids = torch.zeros_like(test)  # 2D ellipsoid network at low resolution
denoised_low_resolution_2d_mnist = torch.zeros_like(test)  # 2D MNIST network at low resolution
denoised_pet_resolution_2d_ellipsoids = torch.zeros_like(test)  # 2D ellipsoid network at mini-PET resolution
denoised_pet_resolution_2d_mnist = torch.zeros_like(test)  # 2D MNIST network at mini-PET resolution

denoisers_2d = (
    (low_resolution_2d_ellipsoids_net, denoised_low_resolution_2d_ellipsoids),
    (low_resolution_2d_mnist_net, denoised_low_resolution_2d_mnist),
    (pet_resolution_2d_ellipsoids_net, denoised_pet_resolution_2d_ellipsoids),
    (pet_resolution_2d_mnist_net, denoised_pet_resolution_2d_mnist),
)

with torch.no_grad():
    # Test 2D denoisers slice-by-slice along axis 0; each slice is fed as the
    # 4D batch (1, 1, H, W) that Conv2d layers require.
    for net, denoised in denoisers_2d:
        net.to(device).eval()
        for i in range(test.shape[2]):
            denoised[:, :, i] = net(test[:, :, i])

    # Test 3D denoiser on the whole (1, 1, D, H, W) volume
    denoised_3d = denoise_net_3d.to(device).eval()(test)

In [None]:
# Result comparison with fly-through images:
# 10 iterations MLEM vs. the 2D low-resolution ellipsoid denoiser.
# Denoised tensors are (1, 1, D, H, W); [0, 0] drops the batch/channel axes so
# both volumes are indexed the same way (axial = 0, coronal = 1, sagittal = 2).

volumes_to_show = (
    ('10 iterations MLEM reconstruction', mlem_10),
    ('Denoising with ellipsoid network (2D low resolution)',
     denoised_low_resolution_2d_ellipsoids[0, 0]),
)
for title, volume in volumes_to_show:
    volume_np = volume.detach().cpu().numpy()
    for view_name, axis in (('Axial', 0), ('Coronal', 1), ('Sagittal', 2)):
        fig, axes = plt.subplots(nrows=1, ncols=10, figsize=(30, 3))
        slice_indexes = np.linspace(0, volume_np.shape[axis] - 1, 10).astype(int)
        for ax, ind in zip(axes, slice_indexes):
            ax.imshow(np.take(volume_np, ind, axis=axis), cmap='gray')
            ax.axis('off')
        fig.suptitle(f'{title} - {view_name} slices')
        plt.show()










In [None]:
# Result comparison with fly-through images:
# 10 iterations MLEM vs. the 2D low-resolution MNIST denoiser.
# Denoised tensors are (1, 1, D, H, W); [0, 0] drops the batch/channel axes so
# both volumes are indexed the same way (axial = 0, coronal = 1, sagittal = 2).

volumes_to_show = (
    ('10 iterations MLEM reconstruction', mlem_10),
    ('Denoising with MNIST network (2D low resolution)',
     denoised_low_resolution_2d_mnist[0, 0]),
)
for title, volume in volumes_to_show:
    volume_np = volume.detach().cpu().numpy()
    for view_name, axis in (('Axial', 0), ('Coronal', 1), ('Sagittal', 2)):
        fig, axes = plt.subplots(nrows=1, ncols=10, figsize=(30, 3))
        slice_indexes = np.linspace(0, volume_np.shape[axis] - 1, 10).astype(int)
        for ax, ind in zip(axes, slice_indexes):
            ax.imshow(np.take(volume_np, ind, axis=axis), cmap='gray')
            ax.axis('off')
        fig.suptitle(f'{title} - {view_name} slices')
        plt.show()


In [None]:
# Result comparison with fly-through images:
# 10 iterations MLEM vs. the 2D mini-PET-resolution ellipsoid denoiser.
# Denoised tensors are (1, 1, D, H, W); [0, 0] drops the batch/channel axes so
# both volumes are indexed the same way (axial = 0, coronal = 1, sagittal = 2).

volumes_to_show = (
    ('10 iterations MLEM reconstruction', mlem_10),
    ('Denoising with ellipsoid network (2D PET resolution)',
     denoised_pet_resolution_2d_ellipsoids[0, 0]),
)
for title, volume in volumes_to_show:
    volume_np = volume.detach().cpu().numpy()
    for view_name, axis in (('Axial', 0), ('Coronal', 1), ('Sagittal', 2)):
        fig, axes = plt.subplots(nrows=1, ncols=10, figsize=(30, 3))
        slice_indexes = np.linspace(0, volume_np.shape[axis] - 1, 10).astype(int)
        for ax, ind in zip(axes, slice_indexes):
            ax.imshow(np.take(volume_np, ind, axis=axis), cmap='gray')
            ax.axis('off')
        fig.suptitle(f'{title} - {view_name} slices')
        plt.show()


In [None]:
# Result comparison with fly-through images:
# 10 iterations MLEM vs. the 3D ellipsoid denoiser.
# The 3D output is (1, 1, D, H, W); [0, 0] drops the batch/channel axes so both
# volumes are indexed the same way (axial = 0, coronal = 1, sagittal = 2).

volumes_to_show = (
    ('10 iterations MLEM reconstruction', mlem_10),
    ('Denoising with ellipsoid network (3D PET resolution)', denoised_3d[0, 0]),
)
for title, volume in volumes_to_show:
    volume_np = volume.detach().cpu().numpy()
    for view_name, axis in (('Axial', 0), ('Coronal', 1), ('Sagittal', 2)):
        fig, axes = plt.subplots(nrows=1, ncols=10, figsize=(30, 3))
        slice_indexes = np.linspace(0, volume_np.shape[axis] - 1, 10).astype(int)
        for ax, ind in zip(axes, slice_indexes):
            ax.imshow(np.take(volume_np, ind, axis=axis), cmap='gray')
            ax.axis('off')
        fig.suptitle(f'{title} - {view_name} slices')
        plt.show()


In [None]:
# Select a homogeneous region of the state-of-the-art (MLEM) and of the denoised
# versions, and compare the noise in that region with the
# Coefficient of Variation, CV = std / mean (lower means less noise).

# TODO: set ROI to a homogeneous region as (axis0, axis1, axis2) slices, e.g.
# (slice(10, 20), slice(60, 80), slice(60, 80)). The original cell left the
# indices blank; until set, the whole volume is used.
ROI = (slice(None), slice(None), slice(None))


def coefficient_of_variation(volume, roi=ROI):
    """Return std / mean of ``volume[roi]`` for a 3D torch tensor or array."""
    if torch.is_tensor(volume):
        volume = volume.detach().cpu().numpy()
    values = np.asarray(volume)[roi].ravel()
    return np.std(values) / np.mean(values)


noise_mlem = coefficient_of_variation(mlem_10)  # Noise in 10 iterations MLEM
print('MLEM (10 iterations):', noise_mlem)
# Denoised tensors are (1, 1, D, H, W); [0, 0] selects the 3D volume.
noise_mnist = coefficient_of_variation(denoised_low_resolution_2d_mnist[0, 0])  # Noise in MNIST network denoised images
print('MNIST network (2D low resolution):', noise_mnist)
noise_ellipsoids_low_resolution = coefficient_of_variation(denoised_low_resolution_2d_ellipsoids[0, 0])  # Noise in ellipsoid network at low resolution in 2D
print('Ellipsoid network (2D low resolution):', noise_ellipsoids_low_resolution)
noise_ellipsoids_pet_resolution = coefficient_of_variation(denoised_pet_resolution_2d_ellipsoids[0, 0])  # Noise in ellipsoid network at mini-PET resolution in 2D
print('Ellipsoid network (2D PET resolution):', noise_ellipsoids_pet_resolution)
noise_3d = coefficient_of_variation(denoised_3d[0, 0])  # Noise in 3D denoising network
print('3D network:', noise_3d)

In [None]:
# Process noise values in some Excel file, for example
