# Inference
You DON'T need a GPU to run inference. It is fast even in a CPU.

This notebook segments data indicating by the subject numbers in inference_input.txt. Its pre-loaded with one subjects as an example, however you need to follow the download instructions to get more.

In [1]:
!cat inference_input.txt

100206


Note that these subjects should be in the HCP_processed_data, following the organization that results from our pre-processing pipeline (for the T1 and difussion files), as in the example:

In [2]:
!tree ./Data/HCP_processed_data/100206

[01;34m./Data/HCP_processed_data/100206[00m
├── [01;34mFSL[00m
│   └── [01;31mT1_first_all_fast_firstseg_1.25_nearest.nii.gz[00m
├── [01;34mFreeSurfer[00m
│   └── [01;31maparc+aseg_1.25_nearest.nii.gz[00m
├── [01;34mQuickNAT[00m
│   └── [01;31msegmentation_acpc_dc_restore_1.25_nearest.nii.gz[00m
├── [01;34mSTAPLE[00m
│   └── [01;31mSTAPLE_th0.5_thalamus_1.25.nii.gz[00m
├── [01;32mT1w_acpc_dc_restore_1.25.nii.gz[00m
└── [01;34mdiffusion[00m
    ├── [01;31mFA.nii.gz[00m
    ├── [01;31mMD.nii.gz[00m
    ├── [01;31mRD.nii.gz[00m
    ├── [01;31mevals.nii.gz[00m
    ├── [01;31mevalue1.nii.gz[00m
    └── [01;31mevecs.nii.gz[00m

5 directories, 11 files


If you download the rest of the data from the benchmark, it should be already in that format.

In [3]:
import os
import nibabel as ni
import numpy as np
import pytorch_lightning as pl
import argparse

import torch
from CNNs.unet import UNet
from Utils.transforms import My_transforms

In [14]:
def load_nii_file(file_path):
    data = ni.load(file_path)
    volume = np.nan_to_num(data.get_data().squeeze())
    return volume

def load_files(file_paths, d_type=None):
    images = []
    for path in file_paths:
        if d_type == None:
            images.append(load_nii_file(path))
        else: 
            images.append(load_nii_file(path).astype(d_type))
    return images


def to_onehot(matrix, labels=[], single_foregound_lable=True, background_channel=True, onehot_type=np.dtype(np.float32)):
    matrix = np.around(matrix)
    if len(labels) == 0:
        labels = np.unique(matrix) 
        labels = labels[1::]
    
    mask = np.zeros(matrix.shape, dtype=onehot_type)
    for i, label in enumerate(labels):
        mask += ((matrix == label) * (i+1))
   
    if single_foregound_lable:
        mask = (mask > 0)
        labels = [1]
        
    labels_len = len(labels)        
        
    onehot = np.zeros((labels_len+1,) + matrix.shape, dtype=onehot_type) 
    for i in range(mask.max()+1):
        onehot[i] = (mask == i)  
        
    if background_channel == False:
        onehot = onehot[1::] 
        
       
    return mask, onehot, labels

In [15]:
class Segmentor(pl.LightningModule):
    def __init__(self, hparams: argparse.Namespace):
        super().__init__()

        self.save_hyperparameters(hparams)        

        if "unet" in self.hparams.cnn_architecture:
            architecture = UNet(nin_channels=self.hparams.n_inchannels, 
                                nout_channels=self.hparams.n_outchannels, 
                                init_features=self.hparams.init_features)
        elif self.hparams.cnn_architecture == "coedet":
            architecture = CoEDET(nin=self.hparams.n_inchannels, nout=self.hparams.n_outchannels, 
                                  apply_sigmoid=self.hparams.apply_sigmoid)
        else:
            raise ValueError(f"Unsupported cnn_architecture {self.hparams.cnn_architecture}")

        self.model = architecture
    
        
        ttransform_scale=None
        ttransform_angle=None
        ttransform_flip_prob=None
        ttransform_sigma=None
        ttransform_ens_treshold=None
        if "taug_scale" in self.hparams:
            ttransform_scale = self.hparams.taug_scale
        if "taug_angle" in self.hparams:
            ttransform_angle = self.hparams.taug_angle
        if "taug_flip_prob" in self.hparams:
            ttransform_flip_prob = self.hparams.taug_flip_prob
        if "taug_sigma" in self.hparams:
            ttransform_sigma = self.hparams.taug_sigma
        if "taug_ens_treshold" in self.hparams:
            ttransform_ens_treshold = self.hparams.aug_ens_treshold
        self.train_transforms = My_transforms(scale=ttransform_scale,
                                         angle=ttransform_angle,
                                         flip_prob=ttransform_flip_prob,
                                         sigma=ttransform_sigma,
                                         ens_treshold=ttransform_ens_treshold
                                        )
        vtransform_scale=None
        vtransform_angle=None
        vtransform_flip_prob=None
        vtransform_sigma=None
        vtransform_ens_treshold=None
        if "vaug_scale" in self.hparams:
            ttransform_scale = self.hparams.vaug_scale
        if "vaug_angle" in self.hparams:
            ttransform_angle = self.hparams.vaug_angle
        if "vaug_flip_prob" in self.hparams:
            ttransform_flip_prob = self.hparams.vaug_flip_prob
        if "vaug_sigma" in self.hparams:
            ttransform_sigma = self.hparams.vaug_sigma
        if "vaug_ens_treshold" in self.hparams:
            ttransform_ens_treshold = self.hparams.vaug_ens_treshold
        self.val_transforms = My_transforms(scale=vtransform_scale,
                                         angle=vtransform_angle,
                                         flip_prob=vtransform_flip_prob,
                                         sigma=vtransform_sigma,
                                         ens_treshold=vtransform_ens_treshold
                                        )
        
  

    def forward(self, x):
        return self.model(x)

    def training_step(self, train_batch, batch_idx):
        loss = None

        x, y = train_batch
        logits = self.forward(x)
        loss = CombinedLoss(logits, y, 
                            self.hparams.train_loss_funcs, 
                            self.hparams.lossweighs,
                            func_weights=self.hparams.func_weights)

        if self.hparams.train_metric == 'DiceMetric_weighs':
            train_metric = DiceMetric_weighs(y_pred=logits, y_true=y,
                                             weights=self.hparams.train_metricweighs, treshold=0.5)
        else:
            raise ValueError(f"Unsupported metric {self.hparams.train_metric}")

        self.log("loss", loss, on_epoch=True, on_step=True)
        self.log("train_metric", train_metric, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, val_batch, batch_idx):
        logits = None

        x, y = val_batch
        logits = self.forward(x)
        loss = CombinedLoss(logits, y, 
                            self.hparams.val_loss_funcs, 
                            self.hparams.lossweighs,
                            func_weights=self.hparams.func_weights)
    
        if self.hparams.val_metric == 'DiceMetric_weighs':
            val_metric = DiceMetric_weighs(y_pred=logits, y_true=y,
                                             weights=self.hparams.val_metricweighs, treshold=0.5)
        else:
            raise ValueError(f"Unsupported metric {self.hparams.val_metric}")

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_metric", val_metric, on_epoch=True, on_step=False, prog_bar=True)


        
        self.log("learning_rate_test", self.optimizer.param_groups[0]['lr'], on_epoch=True, on_step=False, prog_bar=False)

    
    def get_optimizer_by_name(self, name, lr):
        '''
        Note que você pode adicionar funções suas ao LightningModule 
        Defini essa função para poder selecionar o otimizador por uma string.
        '''
        if name == "Adam":
            return Adam(self.model.parameters(), lr=lr)
        elif name == "SGD":
            return SGD(self.model.parameters(), lr=lr)
        else:
            raise ValueError(f"Unsupported optimizer: {name}")
            

    def configure_optimizers(self):
        '''
        Select optimizer and scheduling strategy according to hparams.
        '''
        optimizer = self.get_optimizer_by_name(self.hparams.opt_name, 
                                               self.hparams.lr)

        if self.hparams.lr_decay_policy == 'step':
            scheduler = StepLR(optimizer, self.hparams.scheduling_patience_lrepochs, self.hparams.lr_decay_factor, verbose=True)
            print('STEP - scheduling_patience_lrepochs = ', self.hparams.scheduling_patience_lrepochs, ' lr_decay_factor = ', self.hparams.lr_decay_factor)
        elif self.hparams.lr_decay_policy == 'plateau':
            print('PLATEAU - scheduling_patience_lrepochs = ', self.hparams.scheduling_patience_lrepochs, ' lr_decay_factor = ', self.hparams.lr_decay_factor)

            self.optimizer = optimizer
            lr_scheduler =  {
                           'scheduler': ReduceLROnPlateau(optimizer),
                           'mode': self.hparams.lr_decay_mode,
                           'factor': self.hparams.lr_decay_factor,
                           'patience': self.hparams.scheduling_patience_lrepochs,
                           'threshold': 0.0001,
                           'threshold_mode': self.hparams.lr_decay_threshold_mode,
                           'cooldown': 0,
                           'min_lr': 0,
                           'eps': 1e-08,
                           'monitor': 'val_loss',
                           'verbose': True
                           }
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
        
         
            
        else:
            raise ValueError(f"Unsupported lr_decay_policy {self.hparams.lr_decay_policy}")
            

        return [optimizer], [scheduler]
    

In [16]:
# # #fine tuning freeze
#all channels:
axial_model_path = "./checkpoints/fine_tuning_unet_single_label_freeze/unet_single_label_psz064_axial_15-11-2021_20-29-epoch=03-val_loss=0.11.ckpt"
coronal_model_path = "./checkpoints/fine_tuning_unet_single_label_freeze/unet_single_label_psz064_coronal_15-11-2021_20-47-epoch=02-val_loss=0.11.ckpt"
sagittal_model_path = "./checkpoints/fine_tuning_unet_single_label_freeze/unet_single_label_psz064_sagittal_15-11-2021_21-48-epoch=04-val_loss=0.11.ckpt"


In [17]:
# Paths
dataset_folder = './Data/HCP_processed_data/'

dest_folder = './Predictions/fine_tuning_unet_single_label_freeze/all_channels/'
os.makedirs(dest_folder, exist_ok=True)

subject_list = 'inference_input.txt'

# Experiment definition
experiment_name = 'single_label'
Slice_views = ['axial', 'coronal', 'sagittal']


percentil_filt = 99.98
normalize_volumes = [0,1]

prediction_threshold = 0.5



evalue1_sufix = 'diffusion/evalue1.nii.gz'
FA_sufix = 'diffusion/FA.nii.gz'
RD_sufix = 'diffusion/RD.nii.gz'
MD_sufix = 'diffusion/MD.nii.gz'
T1_sufix = 'T1w_acpc_dc_restore_1.25.nii.gz'
img_paths = [evalue1_sufix, FA_sufix, RD_sufix, MD_sufix, T1_sufix]


mask_free_sufix = 'FreeSurfer/aparc+aseg_1.25_nearest.nii.gz'

save_prediction = True

input_d_type='float32'

In [18]:
%%time

subjects = [line.strip() for line in open(subject_list)]

MASKS = []
STAPLE = []
FREE = []
FSL = []
QUI = []
MAN = []
PREDICTIONS = []
PREDICTIONS_fullsize = []

for subject in subjects:
    print('subject = ', subject)
        
    images = load_files([dataset_folder + subject + '/' +  s for s in img_paths])

    if percentil_filt > 0:
        for i in range(len(images)):
            images[i][images[i] > np.percentile(images[i], percentil_filt)] = np.percentile(images[i], percentil_filt)
    
    if len(normalize_volumes) == 2:
        for i in range(len(images)):
            images[i] = images[i] * ((normalize_volumes[1]-normalize_volumes[0])/(images[i].max()-images[i].min()))
            images[i] = images[i] - images[i].min() + normalize_volumes[0]          
   
    img_crop = np.array(images)[:, :144, 15:159, :144]
    
    PREDS = []
    for Slice_view in Slice_views:
        
        # reorient images
        if Slice_view == 'axial':
            img_crop_reoriented = np.transpose(img_crop, (3, 0, 1, 2))
            model_path = axial_model_path
        elif Slice_view == 'coronal':
            img_crop_reoriented = np.transpose(img_crop, (2, 0, 1, 3))
            model_path = coronal_model_path
        elif Slice_view == 'sagittal':
            img_crop_reoriented = np.transpose(img_crop, (1, 0, 2, 3))
            model_path = sagittal_model_path

        trained_model = Segmentor.load_from_checkpoint(model_path).eval()

        with torch.no_grad():
            preds = trained_model(torch.tensor(img_crop_reoriented)).cpu().numpy()

        # reorient images
        if Slice_view == 'axial':
            preds = np.transpose(preds, (1, 2, 3, 0))
        elif Slice_view == 'coronal':
            preds = np.transpose(preds, (1, 2, 0, 3))
        elif Slice_view == 'sagittal':
            preds = np.transpose(preds, (1, 0, 2, 3))

        PREDS.append(preds) #prediction for each slice
    
    prediction = np.zeros(preds.shape)
    for pred in PREDS:
        prediction = prediction + pred/len(PREDS)
    
    if save_prediction:
        FREE_file = ni.load(dataset_folder + subject + '/' + mask_free_sufix)
        FREE_data = FREE_file.get_data()  
                            
        PREDICTION_fullsize = np.zeros(images[0].shape)
        PREDICTION_fullsize[:144, 15:159, :144] = (prediction[1] >= prediction_threshold)  #save only the thalamus channel
        PREDICTIONS_fullsize.append(PREDICTION_fullsize)
        prediction_file = ni.Nifti1Image(PREDICTION_fullsize.astype(FREE_data.dtype), affine=FREE_file.affine, header=FREE_file.header)
        ni.save(prediction_file, dest_folder + subject + '.nii.gz')
        
    PREDICTIONS.append(np.asarray(prediction))

subject =  100206



* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  volume = np.nan_to_num(data.get_data().squeeze())


CPU times: user 1min 16s, sys: 24.8 s, total: 1min 41s
Wall time: 15.2 s



* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0


# Outputs are on the Predictions folder!