Setting up a training loop with validation.

This demo is a jupyter notebook, i.e. intended to be run step by step.

Author: Imraj Singh

First version: 13th of May 2022

CCP SyneRBI Synergistic Image Reconstruction Framework (SIRF).
Copyright 2022 University College London.

This is software developed for the Collaborative Computational Project in Synergistic Reconstruction for Biomedical Imaging (http://www.ccpsynerbi.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Setting up the training

This is very standard PyTorch parlance...

The ideas:
* Setup acquisition model
* Setup dataset
* Setup model
* Setup training loop with validation

In [None]:
# Import the PET reconstruction engine
import sirf.STIR as pet
# Set the verbosity (level 0 is passed here)
pet.set_verbosity(0)
# Store temporary sinograms in RAM
pet.AcquisitionData.set_storage_scheme("memory")
# SIRF STIR message redirector; info, warning and error channels are all
# given None (presumably this suppresses them -- confirm against SIRF docs)
import sirf
msg = sirf.STIR.MessageRedirector(info=None, warn=None, errr=None)
# Load dataset and model (project-local modules)
from odl_funcs.ellipses import EllipsesDataset
from lpd_net import LearnedPrimalDual
# Import standard extra packages
import matplotlib.pyplot as plt
import os
import numpy as np
import time
import torch
from tqdm.notebook import trange, tqdm
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image size handed to create_uniform_image, and the training mini-batch size
size_xy = 128
batch = 10

from sirf.Utilities import examples_data_path

# Template sinogram taken from the SIRF PET example data
template_file = examples_data_path('PET') + '/thorax_single_slice/template_sinogram.hs'
sinogram_template = pet.AcquisitionData(template_file)

# Create the acquisition model and set it up on the sinogram/image templates
acq_model = pet.AcquisitionModelUsingParallelproj()
image_template = sinogram_template.create_uniform_image(1.0, size_xy)
acq_model.set_up(sinogram_template, image_template)

# Training data: 50 samples, served in shuffled mini-batches of `batch`
train_set = EllipsesDataset(acq_model.forward, image_template,
                            mode="train", n_samples=50)
train_dataloader = torch.utils.data.DataLoader(
    train_set, batch_size=batch, shuffle=True)

# Validation data: served one sample at a time
valid_set = EllipsesDataset(acq_model.forward, image_template, mode="valid")
valid_dataloader = torch.utils.data.DataLoader(
    valid_set, batch_size=1, shuffle=True)

# Learned primal-dual network, moved to the selected device
model = LearnedPrimalDual(
    image_template,
    sinogram_template,
    acq_model,
    n_iter=5,
    n_primal=5,
    n_dual=5,
).to(device)

Let's set up the training loop with validation and a very simple "data logger".

In [None]:
# Training loop with per-epoch validation.
#
# Each epoch first evaluates the model on one fixed validation sample, logs
# the loss and the reconstructed image, remembers the weights that gave the
# lowest validation loss so far, and then runs one pass over the training
# data. At the end the best weights and the log are written to disk.
import copy

lr = 1e-6
total_epochs = 500

# Sum (not mean) of squared errors over all elements
criterion = torch.nn.MSELoss(reduction='sum').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.99, 0.999))

# Very simple "data logger": plain lists that are appended to during training
data_log = {}
data_log["valid_loss"] = []
data_log["valid_image"] = []
data_log["loss"] = []

# Lowest validation loss seen so far; inf so the first epoch always counts
min_valid_loss = float("inf")

# A single validation sample is drawn once and reused for every epoch
x_gt_valid, y_valid = next(iter(valid_dataloader))
x_gt_valid, y_valid = x_gt_valid.float().to(device), y_valid.float().to(device)

# map_location lets a checkpoint written on a GPU be loaded on a CPU-only host
model.load_state_dict(
    torch.load('quasi_trained_model.torch_model', map_location=device))

# Start from the loaded weights so best_model is defined even if no epoch runs.
# deepcopy is required: state_dict() returns references to the live parameters,
# which the optimizer keeps modifying in place.
best_model = copy.deepcopy(model.state_dict())

pbar1 = trange(total_epochs, position=0, leave=True, desc='Epochs')
pbar2 = tqdm(train_dataloader, position=1, leave=True, desc='Iterations')
for _ in pbar1:
    # ---- validation ----
    model.eval()
    # No gradients are needed here, so avoid building the autograd graph
    with torch.no_grad():
        x_valid = model(y_valid)
        loss_valid = criterion(x_gt_valid, x_valid)
    valid_loss_value = loss_valid.item()
    pbar1.set_description("Epoch, loss {:10.2f}".format(valid_loss_value))
    data_log["valid_loss"].append(valid_loss_value)
    data_log["valid_image"].append(x_valid[0,0,...].detach().cpu().numpy())
    if min_valid_loss > valid_loss_value:
        # Record the new minimum and take a snapshot of the weights
        min_valid_loss = valid_loss_value
        best_model = copy.deepcopy(model.state_dict())

    # ---- training ----
    # Restart the inner bar with the real number of batches per epoch
    pbar2.reset(total=len(train_dataloader))
    model.train()
    for x_gt, y in train_dataloader:
        x_gt, y = x_gt.float().to(device), y.float().to(device)
        # Forward pass: reconstruct the image x from the measured data y
        x = model(y)
        # Compute the loss against the ground truth
        loss = criterion(x_gt, x)
        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1 before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
        data_log["loss"].append(loss.item())
        pbar2.update()
        pbar2.set_description("Batch sample, loss {:10.2f}".format(loss.item()))

# Persist the best weights (overwrites the checkpoint loaded above) and the log.
# np.save appends '.npy', so the log ends up in 'results.npy'.
torch.save(best_model, 'quasi_trained_model.torch_model')
np.save('results',data_log)

Let's look at some of the results

In [None]:
# Reload the saved log; it is a pickled dict, hence allow_pickle and .item()
r = np.load('results.npy', allow_pickle=True).item()
valid_losses = r['valid_loss']

# Optional: uncomment to plot the loss curves
#plt.plot(r['loss'])
#plt.plot(np.log(r['valid_loss']))

# Index of the epoch with the lowest validation loss
min_loss = np.argmin(valid_losses)
print(min_loss)
print(valid_losses[min_loss])
# Show the validation reconstruction logged at that epoch
plt.imshow(r['valid_image'][min_loss])

# Exercises

* Do inference on BrainWeb data (i.e. add a test set)
* See the reconstruction at various points in the network
* Compare with OSEM reconstruction
* Attempt to improve the model (change training parameters?)
* Use "realistic" acquisition model for training data generation, but "simple" acquisition model for reconstruction
* Add more physics to the forward model?