The following code serves as a representation of the workflow followed during this research internship. 

The goal is to use a dataset of whole-body PET/MRI (T2 and Dixon sequences) and CT scans to build and train a model that creates synthetic CT (sCT) images from the MRI images. These sCT images are then meant to be used for body composition analysis: https://github.com/UMEssen/Body-and-Organ-Analysis

Body composition is a biomarker that can be used to determine treatment plans in oncology and cardiology. The CT image is fed to a software tool that separates it into different regions, using a model that thresholds the Hounsfield units (HU) to specific intensity ranges. The regions to be determined are:
- Subcutaneous adipose tissue
- Total adipose tissue
- Visceral adipose tissue
- Muscle volume
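To make the thresholding idea concrete, here is a minimal NumPy sketch that measures a tissue volume from a CT volume expressed in HU. The windows used (-190 to -30 HU for adipose tissue, -29 to 150 HU for muscle) are commonly used values assumed here, not necessarily the ones used by the Body-and-Organ-Analysis tool, and the helper `tissue_volume_ml` is hypothetical. Separating subcutaneous from visceral adipose tissue additionally requires a region mask, which is what the segmentation model provides; the optional `mask` argument stands for it.

```python
import numpy as np

# Commonly used HU windows (assumption, not taken from the BOA software)
HU_RANGES = {
    'adipose': (-190, -30),
    'muscle': (-29, 150),
}


def tissue_volume_ml(ct_hu, voxel_spacing_mm, hu_range, mask=None):
    """Volume (mL) of the voxels whose HU lies in hu_range, optionally inside a region mask."""
    low, high = hu_range
    tissue = (ct_hu >= low) & (ct_hu <= high)
    if mask is not None:
        tissue &= mask  # e.g. a hypothetical subcutaneous or visceral region mask
    voxel_volume_ml = np.prod(voxel_spacing_mm) / 1000.0  # 1 mL = 1000 mm^3
    return tissue.sum() * voxel_volume_ml


# Synthetic CT: air everywhere, a fat cube with a small muscle cube inside
ct = np.full((10, 10, 10), -1000.0)
ct[2:8, 2:8, 2:8] = -100.0   # 216 voxels of fat...
ct[4:6, 4:6, 4:6] = 50.0     # ...8 of which are replaced by muscle
spacing = (1.0, 1.0, 5.0)    # voxel size in mm

fat = tissue_volume_ml(ct, spacing, HU_RANGES['adipose'])
muscle = tissue_volume_ml(ct, spacing, HU_RANGES['muscle'])
print(fat, muscle)
```

This is also why the sCT does not need to be exact voxel by voxel: what matters is that each voxel falls in the right HU window.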

The software also provides an organ segmentation of the trunk, which can later be tested with the sCT as well.

The advantage of such sCT images is that the body composition report could be produced without an irradiating CT scan. Past sCT models at the Bordet Institute were intended for radiotherapy treatments. The main difference here is that body composition does not require the same precision, because we are more interested in the global composition than in the exact position of every organ in the scan.

For the time being, the data has not yet been retrieved, so the following code was used to explore TorchIO functionalities and learn how to load medical data into a model.

In [1]:
#imports
import torch 
import torchio as tio
from torch.utils.data import DataLoader
import pydicom
import os
import numpy as np
from pydicom.pixel_data_handlers.util import apply_voi_lut
import matplotlib.pyplot as plt
%matplotlib inline



### Load the data using torchio 

It is often recommended to use the NIfTI format for managing medical images. Why? What are the advantages compared with working directly with the DICOM format?

In [4]:
# Function that will randomly split my data into 2 categories: training and testing 
# The validation set will be taken directly from the training set
def random_split(subjects, ratio=0.9):
    num_subjects = len(subjects)
    num_training_subjects = int(ratio * num_subjects)
    num_test_subjects = num_subjects - num_training_subjects

    num_split_subjects = num_training_subjects, num_test_subjects
    return torch.utils.data.random_split(subjects, num_split_subjects)



def Create_dataset(rootdir):
    subject_paths = []
    subjects_list = []

    # Start by recovering the path to each subject folder (stray files are ignored)
    for subject_name in sorted(os.listdir(rootdir)):
        subject_path = os.path.join(rootdir, subject_name)
        if os.path.isdir(subject_path):
            subject_paths.append(subject_path)

    for subject_path in subject_paths:
        # Recover the CT and MRI folders of the subject
        ct_path = None
        mri_path = None

        for subfolder in os.listdir(subject_path):
            subfolder_path = os.path.join(subject_path, subfolder)
            if subfolder.startswith('CT'):
                ct_path = subfolder_path
            elif subfolder.startswith('IRM'):
                mri_path = subfolder_path

        # Paired training needs both scans: skip incomplete subjects
        if ct_path is None or mri_path is None:
            print('Skipping', subject_path, ': CT or MRI folder missing')
            continue

        # Create the torchio Subject with the scans and the patient ID
        sub = tio.Subject(
            CT=tio.ScalarImage(ct_path),
            MRI=tio.ScalarImage(mri_path),
            id_patient=os.path.basename(subject_path),
        )
        subjects_list.append(sub)

    # List of transforms: for training, data augmentation is possible.
    # `include` must use the Subject keys ('CT', 'MRI') so that each image gets its own clamping and rescaling.
    # The CT window keeps negative HU values: adipose tissue lies below 0 HU and is needed for body composition.
    transforms_train = [
        tio.ToCanonical(),
        tio.Clamp(out_min=-1024, out_max=2500, include=['CT']),
        tio.Clamp(out_min=0, out_max=2500, include=['MRI']),
        tio.RescaleIntensity(out_min_max=(0, 1), include=['CT']),
        tio.RescaleIntensity(out_min_max=(0, 1), include=['MRI']),
        tio.RandomFlip(p=0.2),  # data augmentation
        tio.RandomAffine(scales=(0.9, 1.2), degrees=15, p=0.2),  # data augmentation
    ]

    # Same deterministic preprocessing, without augmentation, for validation and testing
    transforms_test = [
        tio.ToCanonical(),
        tio.Clamp(out_min=-1024, out_max=2500, include=['CT']),
        tio.Clamp(out_min=0, out_max=2500, include=['MRI']),
        tio.RescaleIntensity(out_min_max=(0, 1), include=['CT']),
        tio.RescaleIntensity(out_min_max=(0, 1), include=['MRI']),
    ]

    # Compose the transforms that will be applied to the datasets
    transfo_tr = tio.Compose(transforms_train)
    transfo_te = tio.Compose(transforms_test)

    # Random split of the subjects: train+validation / test, then train / validation
    train_val_subjects, testing_subjects = random_split(subjects_list, ratio=0.9)
    training_subjects, validation_subjects = random_split(train_val_subjects, ratio=0.9)

    # Create the final datasets; no augmentation on the validation and testing sets
    training_dataset = tio.SubjectsDataset(training_subjects, transform=transfo_tr)
    validation_dataset = tio.SubjectsDataset(validation_subjects, transform=transfo_te)
    testing_dataset = tio.SubjectsDataset(testing_subjects, transform=transfo_te)

    print('Datasets successfully created!')
    print('=======================================================')
    print(' Training set is made of', len(training_dataset), 'images \n Validation set is made of', len(validation_dataset), 'images \n', 'Testing set is made of', len(testing_dataset), 'images')
    print('=======================================================')

    return training_dataset, validation_dataset, testing_dataset

In [5]:
training_dataset, validation_dataset, testing_dataset = Create_dataset('../Clara intern/whole body')

# Subject.plot() displays the sagittal, coronal and axial views of each image of the subject
#training_dataset[0].plot()

Datasets successfully created!
 Dataset is made of 1 images
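The helper `random_split` is applied twice with `ratio=0.9`, so the subjects end up split roughly 81% / 9% / 10% between training, validation and testing. The sketch below (with a hypothetical `nested_split_sizes` helper) reproduces that arithmetic; it also shows that with a single patient, `int(0.9 * 1) = 0` and every subject lands in the test set. Passing a seeded `torch.Generator` to `torch.utils.data.random_split` is one way to make the split reproducible across runs.

```python
import torch
from torch.utils.data import random_split


def nested_split_sizes(num_subjects, ratio=0.9):
    # Same arithmetic as the random_split helper above, applied twice
    n_train_val = int(ratio * num_subjects)
    n_test = num_subjects - n_train_val
    n_train = int(ratio * n_train_val)
    n_val = n_train_val - n_train
    return n_train, n_val, n_test


print(nested_split_sizes(100))  # (81, 9, 10)
print(nested_split_sizes(1))    # (0, 0, 1)

# Reproducible split: the same seed always yields the same subsets
sizes = list(nested_split_sizes(100))
generator = torch.Generator().manual_seed(42)
train_ids, val_ids, test_ids = random_split(range(100), sizes, generator=generator)
```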


Now that the dataset is created with 3D images for each subject, we need to decide how to exploit them: do we keep them 3D in the model, or do we process them slice by slice to perform a 2D analysis?

The following part explores how this dataset can be preprocessed before being fed to the model. 
With the real data, some preprocessing steps might be done beforehand with the mice software to re-align the PET/MRI with the CT (making sure that both share the same origin and that morphological differences are not too disruptive).

A crucial point for this project is to make sure that the chosen model and preprocessing steps match the concrete clinical application. 

The main issue with medical images is their size: they often contain hundreds of millions of voxels. 
They differ from usual images in several ways: their size, the fact that they are often 3D, their format (often DICOM, which also stores metadata about the patient), and the fact that they cannot easily be downsampled when fine details are needed. 

Because of the amount of information contained in a single medical image, batch sizes tend to be much smaller than for usual images. 
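A quick back-of-the-envelope computation shows why. The sketch below assumes a hypothetical whole-body CT of 512 x 512 voxels per slice and 1000 slices stored as float32, and compares it with one (64, 64, 8) patch, the patch size used later in this notebook. A batch of 16 such volumes would already need 15.6 GiB for the inputs alone, before any network activation.

```python
import numpy as np


def volume_memory_bytes(shape, dtype=np.float32):
    # Number of voxels times the size of one voxel in bytes
    return int(np.prod(shape)) * np.dtype(dtype).itemsize


one_volume = volume_memory_bytes((512, 512, 1000))  # hypothetical whole-body CT
one_patch = volume_memory_bytes((64, 64, 8))        # one training patch

print(one_volume / 1024**3)     # 0.9765625 (GiB for a single float32 volume)
print(one_volume // one_patch)  # 8000 patches fit in the memory of one volume
```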


### Create patches

To train in 2D, one needs to extract slices from the 3D volumes and then aggregate the inference results to rebuild a 3D volume. This is called patch-based training, with a patch size of 1 along one dimension (cite TorchIO patch-based pipeline). 
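A minimal sketch of this slice-based approach in plain PyTorch, assuming a volume in TorchIO's (C, W, H, D) layout after `tio.ToCanonical()`, so that the last axis runs from inferior to superior and each slice along it is axial. The single `Conv2d` layer is a hypothetical stand-in for the 2D sCT generator.

```python
import torch
import torch.nn as nn


def volume_to_slices(volume):
    # (C, W, H, D) -> (D, C, W, H): each axial slice becomes one 2D sample of the batch
    return volume.permute(3, 0, 1, 2)


def slices_to_volume(slices):
    # (D, C, W, H) -> (C, W, H, D): stack the 2D predictions back into a 3D volume
    return slices.permute(1, 2, 3, 0)


# Hypothetical 2D network standing in for the sCT generator
model_2d = nn.Conv2d(1, 1, kernel_size=3, padding=1)

mri = torch.rand(1, 64, 64, 20)  # synthetic single-channel MRI volume (C, W, H, D)
with torch.no_grad():
    sct_slices = model_2d(volume_to_slices(mri))  # (20, 1, 64, 64)
sct = slices_to_volume(sct_slices)                # (1, 64, 64, 20)
print(sct.shape)
```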

In TorchIO, patch samplers randomly extract patches from the volumes of a `SubjectsDataset` like the one created earlier. 
Patch sampling is commonly used with medical images because of their size: working with smaller patches reduces the computation and memory needed. Patch-based algorithms have also been reported to be more effective for some tasks, such as denoising. 

You can choose between uniform and weighted sampling: the uniform sampler extracts random patches from a volume with uniform probability, while the weighted sampler extracts random patches according to a probability map. 
With very small patches, it can be useful to create a probability map for each slice so that sampling focuses on the region of interest rather than on the background. 
With larger patches, however, a uniform sampler is easier to use. 



In [7]:
# UNIFORM sampler
# Choose the sampler for your data
patch_size = (64,64,8)
sampler_uniform = tio.data.UniformSampler(patch_size)

In [49]:
# WEIGHTED sampler
# tio.data.WeightedSampler does not take an array: `probability_map` is the NAME of an image
# stored in each subject. The map must therefore be added to every subject (for example in
# Create_dataset, before building the SubjectsDataset), e.g. with the function below.

# Constants
threshold = 0.5  # background threshold, to adapt to the intensity range of the chosen image


def add_probability_map(subject, image_name='CT', threshold=threshold):
    # Create a matching map that takes 1 where the value is not background, 0 elsewhere
    image = subject[image_name]
    proba_map = (image.data > threshold).float()
    subject.add_image(
        tio.Image(tensor=proba_map, affine=image.affine, type=tio.SAMPLING_MAP),
        'probability_map',
    )
    return subject


sampler = tio.data.WeightedSampler(patch_size=(32, 32, 5), probability_map='probability_map')


In [102]:
samples_per_volume=4
max_length=200 
num_workers=8


patches_training_set = tio.Queue(
    subjects_dataset=training_dataset,
    max_length=max_length,  # Maximum number of patches that can be stored in the queue
    samples_per_volume=samples_per_volume,  # Number of patches to be extracted from each volume
    sampler=sampler,  # The sampler defined previously
    num_workers=num_workers,  # Number of subprocesses used by the queue to load the volumes
    shuffle_subjects=True,
    shuffle_patches=True,
)

patches_validation_set = tio.Queue(
    subjects_dataset=validation_dataset,
    max_length=max_length,
    samples_per_volume=1,
    sampler=sampler,
    num_workers=num_workers,
    shuffle_subjects=True,
    shuffle_patches=True,
)

# Define the loaders; no need to shuffle them as the queue already does it
training_loader_patches = DataLoader(
    patches_training_set,
    batch_size=16,
    num_workers=0,  # Must be 0: the queue already loads the volumes in its own subprocesses
    drop_last=True,
)

validation_loader_patches = DataLoader(
    patches_validation_set,
    batch_size=16,
    num_workers=0,  # Must be 0: the queue already loads the volumes in its own subprocesses
    drop_last=True,
)

### Definition of the model

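In the following section we define the model that will later be trained. Because the MRI and CT scans are paired, a supervised image-to-image model can be used: here, a residual neural network (ResNet), an architecture in which the weight layers learn residual functions with reference to the layer inputs.

The following code is adapted from a 2D ResNet used to denoise diffusion-weighted images (a noisy b800 image guided by a b0 image), which is why the image keys are still `'b0guidance'`, `'b800noisy'` and `'b800reference'`; they need to be replaced by the `'MRI'` and `'CT'` keys of our subjects. 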
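The `network_ResNet` module itself is not shown here. As an illustration of what "learning residual functions" means, here is a minimal sketch of one 2D residual block in PyTorch. The layer choices (reflection padding, 3x3 convolutions, batch normalization, optional dropout) are assumptions inspired by common ResNet generators, not the contents of `network_ResNet`.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Hypothetical 2D residual block: output = x + F(x)."""

    def __init__(self, channels, use_dropout=False):
        super().__init__()
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers += [
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.BatchNorm2d(channels),
        ]
        # F(x): the residual function learned by the weight layers
        self.residual = nn.Sequential(*layers)

    def forward(self, x):
        # The skip connection adds the input back, so the layers only learn a correction
        return x + self.residual(x)


block = ResidualBlock(channels=64, use_dropout=True)
features = torch.randn(2, 64, 32, 32)  # (batch, channels, H, W)
print(block(features).shape)           # torch.Size([2, 64, 32, 32])
```

Because each block preserves the shape of its input, several of them (`n_blocks=9` in the `Model` defaults) can be chained between the encoder and decoder of a generator.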

In [None]:
import os
import random
import time
from collections import OrderedDict
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torchio as tio
from torch.utils.tensorboard import SummaryWriter

#import network as network
import network_ResNet as network
from matplotlib import pyplot as plt
from util import print_log, format_train_log, format_validation_log


class Model(object):
    def __init__(self, expr_dir, seed=None, batch_size=None,
                 epoch_count=1, niter=150, niter_decay=50, beta1=0.5, 
                 lr=0.0003, ngf=64, n_blocks=9, input_nc=1, output_nc=1, 
                 use_dropout=True, norm='batch', max_grad_norm=500.,
                 monitor_grad_norm=True, save_epoch_freq=10, print_freq=15, 
                 display_epoch_freq=1, testing=False, resume=False):

        self.expr_dir = expr_dir
        self.seed = seed
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size

        self.epoch_count = epoch_count
        self.niter = niter
        self.niter_decay = niter_decay
        self.beta1 = beta1
        self.lr = lr
        self.old_lr = self.lr

        self.ngf = ngf
        self.n_blocks = n_blocks
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.use_dropout = use_dropout
        self.norm = norm
        self.max_grad_norm = max_grad_norm

        self.monitor_grad_norm = monitor_grad_norm
        self.save_epoch_freq = save_epoch_freq
        self.print_freq = print_freq
        self.display_epoch_freq = display_epoch_freq
        self.time = datetime.now().strftime("%Y%m%d-%H%M%S")

        # define network we need here
        self.netG = network.define_generator(input_nc=self.input_nc, 
                                             output_nc=self.output_nc, 
                                             ngf=self.ngf,
                                             n_blocks=self.n_blocks, 
                                             use_dropout=self.use_dropout,
                                             device=self.device)

        # define all optimizers here
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=self.lr, 
                                            betas=(self.beta1, 0.999),
                                            )

        self.loss = nn.MSELoss()

        self.resume = resume

        # Create the experiment, TensorBoard and visualization folders if needed
        # (os.makedirs also creates the missing parent folders)
        os.makedirs(os.path.join(expr_dir, 'TensorBoard', self.time, 'training_visuals'), exist_ok=True)
        os.makedirs(os.path.join(expr_dir, 'TensorBoard', self.time, 'testing_visuals'), exist_ok=True)

        if not testing:
            num_params = 0
            with open("%s/nets.txt" % self.expr_dir, 'w') as nets_f:
                num_params += network.print_network(self.netG, nets_f)
                nets_f.write('# parameters: %d\n' % num_params)
                nets_f.flush()

        if resume:
            self.load(os.path.join(self.expr_dir, "latest"), True)
            self.netG.to(self.device)

    def train(self, train_dataset, validation_set):
        self.batch_size = train_dataset.batch_size
        self.save_options()
        out_f = open(f"{self.expr_dir}/results.txt", 'w')
        use_gpu = torch.cuda.is_available()

        tensorboard_writer = SummaryWriter(os.path.join(self.expr_dir, 'TensorBoard', self.time))

        if self.seed is not None:
            print(f"using random seed: {self.seed}")
            random.seed(self.seed)
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)
            if use_gpu:
                torch.cuda.manual_seed_all(self.seed)

        total_steps = 0
        print_start_time = time.time()

        for epoch in range(self.epoch_count, self.niter + self.niter_decay + 1):
            # Back to training mode: validation() switches the network to eval mode
            self.netG.train()
            epoch_start_time = time.time()
            epoch_iter = 0
            total_loss = 0
            for data in train_dataset:
                b0_guidance = data['b0guidance'][tio.DATA].to(self.device)
                b0_guidance = b0_guidance.transpose_(2, 4)
                b0_guidance = torch.squeeze(b0_guidance, dim=2)
                b800_noisy = data['b800noisy'][tio.DATA].to(self.device)
                b800_noisy = b800_noisy.transpose_(2, 4)
                b800_noisy = torch.squeeze(b800_noisy, dim=2)
                b800_reference = data['b800reference'][tio.DATA].to(self.device)
                b800_reference = b800_reference.transpose_(2, 4)
                b800_reference = torch.squeeze(b800_reference, dim=2)
                total_steps += self.batch_size
                epoch_iter += self.batch_size
                
                if self.monitor_grad_norm:
                    losses, visuals, _ = self.train_instance(b0_guidance, b800_noisy, b800_reference)
                else:
                    losses, visuals = self.train_instance(b0_guidance, b800_noisy, b800_reference)
                
                total_loss += losses['Loss']
                    
            loss = total_loss/len(train_dataset)
                
            if total_steps % self.print_freq == 0:
                t = (time.time() - print_start_time) / self.batch_size
                print_log(out_f, format_train_log(epoch, epoch_iter, losses, t))
                tensorboard_writer.add_scalars('Loss', {'train': loss}, total_steps)
                print_start_time = time.time()

            print_start_time = time.time()
            
            if epoch % self.display_epoch_freq == 0:
                self.visualize(visuals, data['b0guidance'][tio.AFFINE], epoch, 
                               epoch_iter / self.batch_size)

            if epoch % self.save_epoch_freq == 0:
                print_log(out_f, 'saving the model at the end of epoch %d, iterations %d' % (epoch, total_steps))
                self.save('latest')

                losses_validation, visuals_validation, affine = self.validation(validation_set)

                self.visualize(visuals_validation, affine, epoch, epoch_iter / self.batch_size, "testing")
                t = (time.time() - print_start_time) / self.batch_size
                print_log(out_f, format_validation_log(epoch, epoch_iter, losses_validation, t))

                tensorboard_writer.add_scalars('Loss', {'Test': losses_validation['Loss']},
                                               total_steps)
            print_log(out_f, 'End of epoch %d / %d \t Time Taken: %d sec' %
                      (epoch, self.niter + self.niter_decay, time.time() - epoch_start_time))

            if epoch > self.niter:
                self.update_learning_rate()

        out_f.close()
        tensorboard_writer.close()


    def train_instance(self, b0_guidance, b800_noisy, b800_reference):
        
        inpt = torch.cat((b800_noisy, b0_guidance), dim=1)
        b800_denoised = self.netG.forward(inpt)

        self.optimizer_G.zero_grad()
        loss = self.loss(b800_denoised, b800_reference)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.netG.parameters(), 
                                                   self.max_grad_norm)
        self.optimizer_G.step()

        losses = OrderedDict([('Loss', loss.data.item())])
        visuals = OrderedDict([('Noisy_b800', b800_noisy.data),
                               ('Denoisy_b800', b800_denoised.data),
                               ('Reference_b800', b800_reference.data)
                               ])
        if self.monitor_grad_norm:
            grad_norm = OrderedDict([('grad_norm', grad_norm)])

            return losses, visuals, grad_norm

        return losses, visuals
    
    
    
    def validation(self, validation_set):
        
        self.netG.eval()
        
        total_loss = 0
        with torch.no_grad():
            for data in validation_set:
                
                b0_guidance = data['b0guidance'][tio.DATA].to(self.device)
                b0_guidance = b0_guidance.transpose_(2, 4)
                b0_guidance = torch.squeeze(b0_guidance, dim=2)
                b800_noisy = data['b800noisy'][tio.DATA].to(self.device)
                b800_noisy = b800_noisy.transpose_(2, 4)
                b800_noisy = torch.squeeze(b800_noisy, dim=2)
                b800_reference = data['b800reference'][tio.DATA].to(self.device)
                b800_reference = b800_reference.transpose_(2, 4)
                b800_reference = torch.squeeze(b800_reference, dim=2)
                
                inpt = torch.cat((b800_noisy, b0_guidance), dim=1)
                b800_denoised = self.netG.forward(inpt)
                
                loss_val = self.loss(b800_denoised, b800_reference)
                total_loss += loss_val.item()

        total_loss = total_loss / len(validation_set)
        
        losses = OrderedDict([('Loss', total_loss)])
        visuals = OrderedDict([('Noisy_b800', b800_noisy.data),
                               ('Denoisy_b800', b800_denoised.data),
                               ('Reference_b800', b800_reference.data),
                               ])
        
        return losses, visuals, data['b0guidance'][tio.AFFINE]



    def update_learning_rate(self):
        lrd = self.lr / self.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr

        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr



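    def save(self, checkpoint_name):
        checkpoint_path = os.path.join(self.expr_dir, checkpoint_name)
        checkpoint = {
            'netG': self.netG.state_dict(),
            'optimizer_G': self.optimizer_G.state_dict()
        }
        torch.save(checkpoint, checkpoint_path)

    def save_options(self):
        # Called at the start of train(): write the simple hyperparameters to a text file
        with open(os.path.join(self.expr_dir, 'options.txt'), 'w') as opt_f:
            for key, value in sorted(vars(self).items()):
                if value is None or isinstance(value, (bool, int, float, str)):
                    opt_f.write(f'{key}: {value}\n')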
        
        
        
    def visualize(self, visuals, affine, epoch, index, state="training"):
        b800_noisy = visuals['Noisy_b800'].cpu().transpose_(1, 3)
        b800_denoised = visuals['Denoisy_b800'].cpu().transpose_(1, 3)
        b800_reference = visuals['Reference_b800'].cpu().transpose_(1, 3)
        
        b800_noisy = b800_noisy[:,None,:,:,:]
        b800_denoised = b800_denoised[:,None,:,:,:]
        b800_reference = b800_reference[:,None,:,:,:]
        
        for i in range(1):
            subject = tio.Subject(
                Noisy=tio.ScalarImage(tensor=b800_noisy[i], affine=affine[i]),
                Denoisy=tio.ScalarImage(tensor=b800_denoised[i], affine=affine[i]),
                Reference=tio.ScalarImage(tensor=b800_reference[i], affine=affine[i]),
            )

            save_path = os.path.join(self.expr_dir, 'TensorBoard', self.time, state + "_visuals")
            save_path = os.path.join(save_path, 'cycle_' + str(epoch) + '_' + str(index) + '_' + str(
                i) + '.png')
            subject.plot(show=False, output_path=save_path)
        plt.close('all')



    def load(self, checkpoint_path, optimizer=False):
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.netG.load_state_dict(checkpoint['netG'])

        if optimizer:
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])



    def eval(self):
        self.netG.eval()
        
        
    def inferencePatch(self, dataset, export_path=None, checkpoint=None, save=False):
        
        checkpoint = checkpoint or os.path.join(self.expr_dir, "latest")
        
        self.load(checkpoint)
        self.eval()

        subjects = []
        for n,subject in enumerate(dataset.subject):
            
            grid_sampler = tio.inference.GridSampler(
                subject,
                dataset.patch_size,
                dataset.patch_overlap)
            
            aggregator = tio.inference.GridAggregator(grid_sampler, 
                                                      overlap_mode='hann')
            # Keep every patch: dropping one would leave part of the volume unaggregated
            loader = torch.utils.data.DataLoader(
                grid_sampler, batch_size=1)
            
            start = time.time()
            with torch.no_grad():
                for i, data in enumerate(loader):
                    
                    b0_guidance = data['b0guidance'][tio.DATA].to(self.device)
                    b0_guidance = b0_guidance.transpose_(2, 4)
                    b0_guidance = torch.squeeze(b0_guidance, dim=2)
                    locations = data[tio.LOCATION]
                    b800_noisy = data['b800noisy'][tio.DATA].to(self.device)
                    b800_noisy = b800_noisy.transpose_(2, 4)
                    b800_noisy = torch.squeeze(b800_noisy, dim=2)
                    
                    inpt = torch.cat((b800_noisy, b0_guidance), dim=1)
                    b800_denoised = self.netG.forward(inpt)
                    
                    b800_denoised = b800_denoised[:,:,None,:,:]
                    b800_denoised = b800_denoised.transpose_(2, 4)
                                    
                    aggregator.add_batch(b800_denoised, locations)
                    
                    print(f"patch {i + 1}/{len(loader)}")
                    
                foreground = aggregator.get_output_tensor()
                # Use the affine of the whole input volume: a patch affine is shifted to the patch origin
                affine = subject['b0guidance'].affine
                                
                subject = tio.Subject(
                    b800denoisy = tio.ScalarImage(tensor=foreground, affine=affine)
                    )
                    
                print(f"{time.time() - start} sec. for evaluation")
                subjects.append(subject)
            
        return subjects
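`update_learning_rate` keeps the learning rate constant until epoch `niter + 1` and then decreases it linearly toward zero over the last `niter_decay` epochs. The sketch below replays that logic and compares it with the same schedule written with PyTorch's `torch.optim.lr_scheduler.LambdaLR`, an alternative that avoids tracking `old_lr` by hand. The helper functions are hypothetical and the single dummy parameter only stands in for `netG`.

```python
import torch


def linear_decay_factor(epoch, niter, niter_decay):
    # Fraction of the initial lr used during `epoch` (1-indexed): 1 for the first
    # niter + 1 epochs, then minus 1 / niter_decay per epoch
    return 1.0 - max(0, epoch - 1 - niter) / niter_decay


def manual_schedule(lr, niter, niter_decay):
    # Replays Model.update_learning_rate: called at the end of every epoch > niter
    lrs, current = [], lr
    for epoch in range(1, niter + niter_decay + 1):
        lrs.append(current)  # lr used during this epoch
        if epoch > niter:
            current -= lr / niter_decay
    return lrs


def lambda_lr_schedule(lr, niter, niter_decay):
    # Same schedule with LambdaLR; its epoch counter starts at 0, hence the +1
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda e: linear_decay_factor(e + 1, niter, niter_decay))
    lrs = []
    for _ in range(niter + niter_decay):
        lrs.append(optimizer.param_groups[0]['lr'])  # lr used during this epoch
        optimizer.step()  # one training epoch would happen here
        scheduler.step()
    return lrs


print(manual_schedule(0.1, niter=3, niter_decay=2))  # [0.1, 0.1, 0.1, 0.1, 0.05]
```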
    