# Predição no conjunto de validação 

* Para a definição do modelo com melhor resultado, foi calculado apenas o Dice médio entre a segmentação de referência (manual ou padrão prata) e a saída de cada modelo. Para isso, foi implementada uma função que usa o método de inferência por janela deslizante do MONAI.

## Importação das bibliotecas e das funções para a predição

In [1]:
import os
import glob
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt

from dataset_mri_3D import DatasetMRI
from data_module_3D import MRICCDataModule 

from unet_module import LightningMRICCv2

from post_processed import get_post_processed_cc3d
from predict_3D import predict_brain_model

## Leitura do conjunto de teste de validação e dos pesos do modelo para a predição

In [2]:

# Root directory that holds both the training logs and the dataset.
path_Data = "../../../../.."
pre_trained_model_path = os.path.join(path_Data, "logs_brain/V2_Brain_segmentatio_normalizationepoch=55-val_loss=0.03.ckpt")
# map_location="cpu" makes the checkpoint tensors deserialize straight onto the
# CPU, so loading does not depend on the device the checkpoint was saved from.
model = LightningMRICCv2.load_from_checkpoint(pre_trained_model_path, map_location="cpu").eval().cpu()
folder_mri = os.path.join(path_Data, "brain_dataset/split_dataset/teste_sep")
# glob returns entries in arbitrary (filesystem) order; sorting keeps the
# subject order, and everything derived from it, reproducible between runs.
subjects = sorted(glob.glob(os.path.join(folder_mri, "*")))

UNet in channels: 1 batch_norm: instance dim: 3d out_channels 1 


In [3]:
# Run the prediction routine over every listed subject and unpack its ten
# return values, one name per line so each result is easy to locate.
# NOTE(review): the meaning of each value is inferred from its name only;
# confirm against predict_3D.predict_brain_model.
(
    vol_data,
    mask_data,
    test_outputs,
    dice_mean,
    dice_per_subject,
    pos_process,
    dice_mean_pos,
    dice_per_subject_pos,
    vol_data_affine,
    test_outputs_save,
) = predict_brain_model(model, subjects)

100%|██████████| 3/3 [01:33<00:00, 31.21s/it]


In [7]:
# Mean Dice before post-processing
dice_mean

0.9492049299437424

In [6]:
# Per-subject Dice before post-processing
dice_per_subject

{'A00063008': 0.9675570726394653,
 'A00062942': 0.9682649374008179,
 'A00064081': 0.9727184772491455,
 'S05': 0.9492348432540894,
 'A00062282': 0.9728366732597351,
 'A00063103': 0.9638174772262573,
 'A00062266': 0.8113753199577332,
 'CC0060': 0.9740077257156372,
 'A00063589': 0.9680245518684387,
 'CC0305': 0.9720557928085327,
 'A00062288': 0.9671775698661804,
 'CC0004': 0.9713920950889587,
 'CC0121': 0.9607912302017212,
 'S11': 0.9671435356140137,
 'A00063368': 0.9311287999153137,
 'CC0181': 0.8020511269569397,
 'CC0120': 0.9703450202941895,
 'S19': 0.9341629147529602,
 'A00062351': 0.9700304865837097,
 'CC0300': 0.9361456632614136,
 'A00063326': 0.963356614112854,
 'CC0180': 0.8336613178253174,
 'CC0276': 0.9813665151596069,
 'A00062934': 0.9711856842041016,
 'A00062917': 0.9683383107185364,
 'CC0080': 0.9770290851593018,
 'CC0240': 0.981547474861145,
 'CC0003': 0.9704250693321228,
 'S35': 0.9497715830802917}

In [4]:
# Mean Dice after post-processing
dice_mean_pos

0.9698941543184477

In [5]:
# Per-subject Dice after post-processing
dice_per_subject_pos

{'A00063008': 0.9676685333251953,
 'A00062942': 0.9682649374008179,
 'A00064081': 0.9727886915206909,
 'S05': 0.9593155980110168,
 'A00062282': 0.972849428653717,
 'A00063103': 0.9643991589546204,
 'A00062266': 0.9713779091835022,
 'CC0060': 0.9740698933601379,
 'A00063589': 0.9680806994438171,
 'CC0305': 0.9744942784309387,
 'A00062288': 0.9672918319702148,
 'CC0004': 0.9714128375053406,
 'CC0121': 0.9607921242713928,
 'S11': 0.9687363505363464,
 'A00063368': 0.9678525328636169,
 'CC0181': 0.973111093044281,
 'CC0120': 0.9727681875228882,
 'S19': 0.9543633460998535,
 'A00062351': 0.9706025123596191,
 'CC0300': 0.9816370606422424,
 'A00063326': 0.9637998342514038,
 'CC0180': 0.9759737253189087,
 'CC0276': 0.9816678762435913,
 'A00062934': 0.9711856842041016,
 'A00062917': 0.9696158170700073,
 'CC0080': 0.9770485758781433,
 'CC0240': 0.981574535369873,
 'CC0003': 0.9705281257629395,
 'S35': 0.9536592960357666}

'A00062282': 0.972849428653717,
 'S11': 0.9687363505363464,
  'CC0300': 0.9816370606422424,

In [None]:
# Show, for one slice index, the input volume, the reference mask, the model
# output and the (halved) difference between output and mask.
# A single index is used for all four images: the original difference plot
# subtracted mask slice 100 from prediction slice 150, comparing two
# different slices.
slice_idx = 150
plt.imshow(vol_data.cpu()[0, 0, slice_idx, :, :], cmap="gray")
plt.show()
plt.imshow(mask_data[slice_idx, :, :], cmap="gray")
plt.show()
prediction_slice = test_outputs.detach().cpu()[slice_idx, :, :]
plt.imshow(prediction_slice, cmap="gray")
plt.show()
plt.imshow((prediction_slice - mask_data[slice_idx, :, :]) / 2, cmap="gray")
plt.show()

In [None]:
# Display the dtype of slice 100 of the detached, CPU-resident output volume.
# NOTE(review): volume_saida_bin is not assigned in any visible cell - confirm
# where it is defined before running this cell.
volume_saida_bin.detach().cpu()[100, :, :].dtype

In [None]:
# Draw a five-panel comparison figure for each slice index in [112, 142):
# input image, reference mask, binarized output, output/image blend and
# mask/image blend.
# The original loop ignored its index and redrew slice 100 on every iteration;
# each figure now shows slice i.
# NOTE(review): volume_saida_bin is not assigned in any visible cell - confirm
# where it is defined before running this cell.

# Loop-invariant conversions are done once instead of once per panel.
image_volume = vol_data.cpu()[0, 0]
output_volume = volume_saida_bin.detach().cpu()
unit_range = {"vmin": 0, "vmax": 1}

for i in range(112, 142):
    image_slice = image_volume[i, :, :]
    mask_slice = mask_data[i, :, :]
    output_slice = output_volume[i, :, :]

    # (data, extra imshow keyword arguments) for each panel, left to right.
    panels = [
        (image_slice, {}),
        (mask_slice, unit_range),
        (output_slice, unit_range),
        ((output_slice + image_slice) / 2, unit_range),
        ((mask_slice + image_slice.numpy()) / 2, unit_range),
    ]

    plt.figure(figsize=(18, 10))
    for position, (panel, limits) in enumerate(panels, start=1):
        plt.subplot(1, 5, position)
        plt.imshow(panel, cmap="gray", **limits)
        # Hide tick marks: pixel coordinates add nothing to the comparison.
        plt.xticks([])
        plt.yticks([])

    plt.show()