Setting up the DL model.

This demo is a Jupyter notebook, i.e. it is intended to be run step by step.

Author: Imraj Singh

First version: 20th of May 2022

CCP SyneRBI Synergistic Image Reconstruction Framework (SIRF).
Copyright 2022 University College London.

This is software developed for the Collaborative Computational Project in Synergistic Reconstruction for Biomedical Imaging (http://www.ccpsynerbi.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Setting up the model

This implementation is adapted from the [Learned Primal Dual PyTorch implementation](https://github.com/cetmann/pytorch-primaldual), with some changes for our needs.
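
Learned Primal-Dual (Adler and Öktem) unrolls a fixed number of primal-dual iterations and replaces each update step with a small CNN: a dual block that sees the current dual variable h, the forward projection K f and the measured data g, and a primal block that sees the current image f and the backprojection Kᵀh. To show only that unrolled structure, the sketch below swaps the learned blocks for the hand-written updates of the primal-dual hybrid gradient (PDHG, Chambolle-Pock) method on a least-squares data term, with a small NumPy matrix as a hypothetical stand-in for the PET forward operator K. It is not the `lpd_net` implementation: in Learned Primal-Dual the two updates are trained networks and the primal and dual variables also carry several "memory" channels.

```python
import numpy as np


def unrolled_primal_dual(K, g, n_iter, sigma, tau):
    """Unrolled primal-dual loop with fixed PDHG updates as stand-ins.

    In Learned Primal-Dual, dual_update and primal_update would be small
    trained CNNs. Here they are the PDHG steps for min_f 0.5 * ||K f - g||^2,
    which converge when sigma * tau * ||K||^2 < 1.
    """
    f = np.zeros(K.shape[1])   # primal variable (image space)
    f_bar = f.copy()           # extrapolated primal variable
    h = np.zeros(K.shape[0])   # dual variable (sinogram space)

    def dual_update(h, Kf, g):
        # Proximal step on the convex conjugate of 0.5 * ||y - g||^2
        return (h + sigma * (Kf - g)) / (1.0 + sigma)

    def primal_update(f, KTh):
        # Step along the backprojected dual variable (no regulariser here)
        return f - tau * KTh

    for _ in range(n_iter):
        h = dual_update(h, K @ f_bar, g)    # sees h, forward projection and data
        f_new = primal_update(f, K.T @ h)   # sees f and the backprojection
        f_bar = 2.0 * f_new - f             # extrapolation with theta = 1
        f = f_new
    return f


# Toy problem: a 3x2 "forward operator" and noise-free data
K = np.array([[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
f_true = np.array([1.0, -0.5])
g = K @ f_true
f_rec = unrolled_primal_dual(K, g, n_iter=200, sigma=0.4, tau=0.4)
print(f_rec)  # close to f_true
```

With fixed updates many iterations are needed; the idea behind the learned version is that trained blocks can reach a good reconstruction in only a handful of unrolled steps.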

First we import the prerequisite packages.

In [None]:
# Import the PET reconstruction engine
import sirf.STIR as pet
# Set the verbosity
pet.set_verbosity(1)
# Store temporary sinograms in RAM
pet.AcquisitionData.set_storage_scheme("memory")
# Import a module that can generate the Shepp-Logan phantom
from odl_funcs.ellipses import EllipsesDataset
import sirf
# Suppress STIR info, warning and error messages
msg = sirf.STIR.MessageRedirector(info=None, warn=None, errr=None)
# Import standard extra packages
import matplotlib.pyplot as plt
import os
import numpy as np
import time
import torch
# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

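# Image size in voxels along x and y
size_xy = 128
from sirf.Utilities import examples_data_path
# Load the single-slice sinogram template from the SIRF example data
sinogram_template = pet.AcquisitionData(
    examples_data_path('PET') + '/thorax_single_slice/template_sinogram.hs')
# Create the acquisition model (parallelproj projector)
acq_model = pet.AcquisitionModelUsingParallelproj()
# Create a uniform image of ones compatible with the sinogram geometry
image_template = sinogram_template.create_uniform_image(1.0, size_xy)
# Set up the acquisition model for this sinogram/image pair
acq_model.set_up(sinogram_template, image_template)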
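
A learned primal-dual network moves between image space and sinogram space using a forward projection and its adjoint, the backprojection (SIRF acquisition models provide these as `acq_model.forward(image)` and `acq_model.backward(sinogram)`). A common sanity check for such a pair is the dot-product (adjoint) test, `<A x, y> = <x, A^T y>` for random `x` and `y`. The sketch below runs the test on a small random NumPy matrix used as a hypothetical stand-in for `acq_model`, so it runs without SIRF. For an explicit matrix the identity holds by construction; for a real projector, where forward and backward projection are separate implementations, a large relative error would point to a mismatched pair.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the acquisition model: a random matrix mapping a
# flattened 8x8 image (64 voxels) to 100 sinogram bins
A = rng.random((100, 64))


def forward(x):
    """Forward projection: image -> sinogram."""
    return A @ x


def backward(y):
    """Backprojection (adjoint of forward): sinogram -> image."""
    return A.T @ y


x = rng.random(64)   # random image
y = rng.random(100)  # random sinogram

lhs = np.dot(forward(x), y)   # <A x, y>, computed in sinogram space
rhs = np.dot(x, backward(y))  # <x, A^T y>, computed in image space
rel_err = abs(lhs - rhs) / abs(lhs)
print(rel_err)  # a value near machine precision
```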

Let's count the number of trainable parameters in the network.

In [None]:
from lpd_net import LearnedPrimalDual

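# Instantiate the model and move it to the selected device
model = LearnedPrimalDual(image_template, sinogram_template,
                          acq_model, n_iter=5, n_primal=5, n_dual=5).to(device)
# Count trainable parameters: numel() gives the total number of elements in
# each tensor (len() would only give the size of its first dimension)
params = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        params += param.numel()
print(str(params) + " trainable parameters")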
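
When counting parameters, note that for a multi-dimensional tensor Python's `len()` returns only the size of the first dimension, while the number of parameters the tensor holds is the product of all its dimensions, which is what PyTorch's `Tensor.numel()` returns. The sketch below shows the difference on plain shape tuples; the layer shapes are hypothetical and not read from `lpd_net`.

```python
import math

# Hypothetical shapes of one 2D convolution layer's parameters:
# weight is (out_channels, in_channels, kernel_h, kernel_w), bias is (out_channels,)
shapes = {"conv.weight": (32, 6, 3, 3), "conv.bias": (32,)}

# len() of a tensor only sees the size of its first dimension
count_len = sum(shape[0] for shape in shapes.values())
# numel() multiplies all dimensions together
count_numel = sum(math.prod(shape) for shape in shapes.values())

print(count_len)    # 64
print(count_numel)  # 1760
```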

How would you display the number of trainable parameters in each part of the neural network?

In [None]:
# Answer below!



























for name, param in model.named_parameters():
    if param.requires_grad:
        # Print each tensor's name, shape and total number of elements
        print(name, tuple(param.shape), param.numel())
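
Printing one line per tensor can get long for an unrolled network. A small helper can instead sum the counts by the leading parts of each parameter name. This is a minimal sketch: `count_by_prefix` is a hypothetical helper, and the example names in it are made up, since the actual prefixes depend on how `lpd_net` names its submodules.

```python
def count_by_prefix(named_parameters, depth=2):
    """Sum trainable parameter counts by the first `depth` parts of each name.

    `named_parameters` is any iterable of (name, tensor) pairs, such as
    model.named_parameters(); each tensor must provide numel() and
    requires_grad, as PyTorch tensors do.
    """
    counts = {}  # insertion order follows the order of the parameters
    for name, param in named_parameters:
        if not param.requires_grad:
            continue
        # e.g. the hypothetical name "primal_list.0.conv.weight"
        # becomes "primal_list.0" for depth=2
        prefix = ".".join(name.split(".")[:depth])
        counts[prefix] = counts.get(prefix, 0) + param.numel()
    return counts
```

With the model above this would be used as `for prefix, n in count_by_prefix(model.named_parameters()).items(): print(prefix, n)`. A `depth` of 1 groups by top-level submodule; larger values split the counts more finely.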