# **Generate OB Maks**
<hr>
<br><br>

In [1]:
import nibabel as nib
import numpy as np
import os
import random
from glob import glob
import fnmatch
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.preprocessing import image
import tensorflow.keras.backend as K

%matplotlib inline

In [2]:
#library versions

#!pip list

nib.__version__

'3.2.1'

In [3]:
np.__version__

'1.19.2'

In [4]:
pd.__version__

'1.3.0'

In [16]:
tf.__version__

'2.5.0'

In [6]:
#para melhor gestão memória GPU por parte do tensorflow

def set_keras_backend(backend):
    print("A acertar o backend e libertar memória da grafica")
    if K.backend() != backend:
        os.environ['KERAS_BACKEND'] = backend
        importlib.reload(K)
        assert K.backend() == backend
    if backend == "tensorflow":
        # Allow memory growth for the GPU
        physical_devices = tf.config.experimental.list_physical_devices('GPU')
        print(physical_devices)
        tf.config.experimental.set_memory_growth(physical_devices[0], True)  

    
def limit_mem():
    K.get_session().close()
    cfg = K.tf.ConfigProto()
    cfg.gpu_options.allow_growth = True
    K.set_session(K.tf.Session(config=cfg))

#set_keras_backend("tensorflow")

In [7]:
#load models

#model to classify images with bulb or without bulb
model_c = load_model("modelo_CNN_treinado_v4.h5") 

#model to calculate centers of images with bulb                  
model_r = load_model("modelo_CNN_R_treinado_v4_174x190.h5")

#segmentation model
model_s = load_model("modelo_FCN_segmentation_32x32.h5" , compile=False)

In [8]:
scan_list=[]

def create_dataset_list(dataPath):
    scansPath = glob(dataPath+'*.gz', recursive=True)
    for file in scansPath:
        scan_list.append(file)
    return sorted(scan_list)
    
dataPath = './input/'
scan_path_list = create_dataset_list(dataPath)
print("NUMBER OF SCANS:", len(scan_path_list))
print("File Names:", scan_path_list)
print("File Name:", scan_path_list[0][10:-7])

NUMBER OF SCANS: 13
File Names: ['./input\\00912_005.nii.gz', './input\\00912_006.nii.gz', './input\\00912_007.nii.gz', './input\\00912_010.nii.gz', './input\\00912_011.nii.gz', './input\\00912_012.nii.gz', './input\\00912_013.nii.gz', './input\\00912_014.nii.gz', './input\\00912_016.nii.gz', './input\\00912_017.nii.gz', './input\\00912_018.nii.gz', './input\\00912_019.nii.gz', './input\\00912_020.nii.gz']
File Name: 912_005


In [9]:
def generate_np_for_resizing(ficheiro_scan):
    num_imagem=0
    image_list = []
    scan = nib.load(ficheiro_scan)
    scan = np.flip(scan.get_fdata().T)
    snp=np.asarray(scan)
    #normalize scan
    snp = np.floor(snp)
    snp /= np.max(snp)  
    for img in snp :
        image_list.append(img)
    image_list = np.array(image_list)
    image_list = image_list.reshape(image_list.shape[0],image_list.shape[1],image_list.shape[2],1)
    return image_list


def generate_all_scan_for_resizing(scan_path_list):
    num_imagem=0
    all_scan_list = []
    for ficheiro_scan in scan_path_list:
        list_img_scan= generate_np_for_resizing(ficheiro_scan)
        all_scan_list.append(list_img_scan)
    all_scan_list = np.array(all_scan_list)
    return all_scan_list

       


np_original_all_scans = generate_all_scan_for_resizing(scan_path_list)

print("np_original_scans.shape:",np_original_all_scans.shape)

np_original_scans.shape: (13, 30, 174, 190, 1)


**Dataset Visualization** 

In [10]:
def visualize_scan(scan):
    num_img=0
    for im in scan:
        if num_img > 3:
            break
        print(num_img)
        plt.axis('on')
        plt.grid(False)
        plt.title("img:%d"%(num_img))
        plt.imshow(im.reshape(im.shape[0],im.shape[1]))
        plt.show()
        num_img+=1
    
def visualize_dataset(images):
    num_scan = random.randint(0, images.shape[0]-1)
    num_img=0
    visualize_scan(images[num_scan])

    
#visualize_dataset(np_original_all_scans)

In [11]:
def visualize_predictions_bulbnobulb_scan(model,scan):
    num_img=0
    for im in scan:
        print(num_img)
        im2 = im.reshape(1,im.shape[0],im.shape[1],im.shape[2])
        y_predictions=model.predict(im2)
        plt.axis('on')
        plt.grid(False)
        plt.title("predict:%f"%(y_predictions))
        plt.imshow(im.reshape(im.shape[0],im.shape[1]))
        plt.show() 
        num_img+=1
            
            
def visualize_predictions_bulbnobulb(model,images):
    num_scan = random.randint(0, images.shape[0])
    num_img=0
    #visualize_predictions_bulbnobulb_scan(model,images[num_scan])
    return num_scan
            
num_scan=visualize_predictions_bulbnobulb(model_c,np_original_all_scans)  


In [12]:
def calc_i_f(n,med,nbb):
    l_half=nbb//2
    r_half=nbb-l_half
    if med-l_half<0:
        l_half=med
        r_half=nbb-l_half
    elif med+r_half>n:
        r_half = n-med
        l_half = nbb-r_half
    vi = med - l_half
    vf = med + r_half
    return int(vi), int(vf)

def process_centers(model_r, scan):
    n =scan.shape[0]
    scan_w = scan.shape[1]
    scan_h = scan.shape[2]
    xp_sum=yp_sum=slicep_sum = 0
    for i in range(n):
        im = scan[i]
        im2 = im.reshape(1,im.shape[0],im.shape[1],im.shape[2])
        p=model_r.predict(im2)
        xp,yp = p[0]
        xp_sum+=xp
        yp_sum+=yp
    xp_med = xp_sum//n
    yp_med = yp_sum//n
    xi,xf = calc_i_f(scan_w,xp_med,BB_IMAGE_WIDTH)
    yi,yf = calc_i_f(scan_h,yp_med,BB_IMAGE_HEIGHT)    
    return  xi,xf,yi,yf 
    

def process_bulnobulb(model_c, scan):
    folga=4
    n =scan.shape[0]
    slicep_sum = 0
    qtp_with_bulbo = 0
    for i in range(n):
        im = scan[i]
        im2 = im.reshape(1,im.shape[0],im.shape[1],im.shape[2])
        scan_has_bulb = model_c.predict(im2) 
        if scan_has_bulb[0] >0.2:
            qtp_with_bulbo+=1
            slicep_sum+=i  
    if qtp_with_bulbo<=1:
        slicep_med = 0
        si=sf=0
    else:
        slicep_med = slicep_sum//qtp_with_bulbo
        if qtp_with_bulbo+folga > n:
            profundidade=n
        else:
            profundidade=qtp_with_bulbo+folga
        si,sf = calc_i_f(n,slicep_med,profundidade)
    return si, sf
    

def generate_BB_scan(model_c, model_r, scan):
    print("Scan shape:",scan.shape)
    si,sf = process_bulnobulb(model_c, scan)
    print("si:%d sf:%d"%(si,sf))
    if si==sf :
        print("Não Tem Bulb")
        return None
    else:
        xi,xf,yi,yf = process_centers(model_r, scan[si:sf,:,:])  
    return scan[si:sf,xi:xf,yi:yf], (si,sf,xi,xf,yi,yf)


def find_mean_centers(model_c, model_r,all_scans):
    num_scan = random.randint(0, all_scans.shape[0])
    num_img=0
    find_mean_centers_scan(model_c,all_scans[num_scan])
    
    
BB_IMAGE_HEIGHT = 32 
BB_IMAGE_WIDTH = 32     
print(np_original_all_scans.shape)
scan_BB, coord_BB= generate_BB_scan(model_c,model_r,np_original_all_scans[num_scan])
print("BB:",scan_BB.shape )  

(13, 30, 174, 190, 1)
Scan shape: (30, 174, 190, 1)
si:7 sf:20
BB: (13, 32, 32, 1)


In [13]:
def visualize_BB(scan):
    num_img=0
    for im in scan:
        print(num_img)
        plt.axis('on')
        plt.grid(False)
        plt.imshow(im.reshape(im.shape[0],im.shape[1]))
        plt.show() 
        num_img+=1
            
            
#num_scan=visualize_BB(scan_BB)

In [14]:
def segmentation_pred_scan(model_s, scan):
    n =scan.shape[0]
    for i in range(n):
        im = scan[i]
        im2 = im.reshape(1,im.shape[0],im.shape[1],im.shape[2])
        pred = model_s.predict(im2)
        fig = plt.figure()
        plt.subplot(1, 2, 1)
        plt.suptitle('Image / Mask Predicted')
        plt.imshow(scan[i,:,:,0], cmap='gray')
        plt.subplot(1, 2, 2)
        plt.imshow(pred[0,:,:,0], cmap='gray')
        #plt.show()
        
        
def segmentation_pred_all_scan(model_s,all_scans):
    num_scan = random.randint(0, images.shape[0])
    num_img=0
    #segmentation_pred_scan(model_s,all_scans[num_scan])
    return scan

    
seg_BB=segmentation_pred_scan(model_s,scan_BB)    

In [15]:
#generate mask for nii scan

def get_affine(ficheiro_scan):
    scan = nib.load(ficheiro_scan)
    mask_affine=scan.affine
    mask_header=scan.header
    return mask_affine,mask_header

def predict_pos_processing(mask,threshold):
    return np.where(mask > threshold, 1., 0.)


def generate_mask(model_s,scan_BB, shape_scan,coord_BB):
    (si,sf,xi,xf,yi,yf) = coord_BB
    masks=np.zeros(shape_scan)
    mask_BB=np.zeros(scan_BB.shape)
    n =scan_BB.shape[0]
    for i in range(n):
        im = scan_BB[i]
        im2 = im.reshape(1,im.shape[0],im.shape[1],im.shape[2])
        pred = predict_pos_processing(model_s.predict(im2), 0.7)
        mask_BB[i]=pred[0]
    masks[si:sf,xi:xf,yi:yf]=mask_BB
    return masks



def visualization_scan_mask(scan, masks):
    n =scan.shape[0]
    for i in range(n):
        f,ax = plt.subplots(1,2,figsize=(15,5))
        for j in range(2):
            ax[j].axis('on')
            ax[j].grid(False)
            ax[j].title.set_text("img:%d "%(i))
        im = scan[i]
        mask = masks[i]
        ax[0].imshow(im.reshape(im.shape[0],im.shape[1]))
        ax[1].imshow(mask.reshape(mask.shape[0],mask.shape[1]))
    plt.show()
    
def save_mask(ficheiro_scan, masks, maskPath):
    filename= ficheiro_scan[10:-7]
    print(filename)
    mask_affine, mask_header=get_affine(ficheiro_scan)
    print(masks.shape)
    masks = masks.reshape(masks.shape[0],masks.shape[1],masks.shape[2])
    mask = np.flip(masks).T
    print(mask.shape)
    seg_mask = nib.Nifti1Image(mask, mask_affine, mask_header)
    nib.save(seg_mask, maskPath+ "/" + filename + "_mask.nii.gz")
    

MASKPATH = 'output/'

def save_all_masks():
    scan_number=np_original_all_scans.shape[0]
    scan_index = 0
    while scan_index < scan_number:

        num_scan=scan_index
        print("num_scan", num_scan)
        print(np_original_all_scans.shape)
        scan_BB, coord_BB= generate_BB_scan(model_c,model_r,np_original_all_scans[num_scan])
        print("BB:",scan_BB.shape )
        shape_scan=np_original_all_scans[num_scan].shape
        print("shape scan:",shape_scan )
        masks = generate_mask(model_s,scan_BB, shape_scan, coord_BB)  
        print("shape mask:",masks.shape )
        #visualization_scan_mask(np_original_all_scans[num_scan], masks)

        #print(scan_path_list[scan_index])
        save_mask(scan_path_list[scan_index],masks, MASKPATH)

        print(scan_index)
        scan_index += 1
        
save_all_masks()
    

num_scan 0
(13, 30, 174, 190, 1)
Scan shape: (30, 174, 190, 1)
si:5 sf:15
BB: (10, 32, 32, 1)
shape scan: (30, 174, 190, 1)
shape mask: (30, 174, 190, 1)
912_005
(30, 174, 190, 1)
(190, 174, 30)
0
num_scan 1
(13, 30, 174, 190, 1)
Scan shape: (30, 174, 190, 1)
si:3 sf:18
BB: (15, 32, 32, 1)
shape scan: (30, 174, 190, 1)
shape mask: (30, 174, 190, 1)
912_006
(30, 174, 190, 1)
(190, 174, 30)
1
num_scan 2
(13, 30, 174, 190, 1)
Scan shape: (30, 174, 190, 1)
si:7 sf:22
BB: (15, 32, 32, 1)
shape scan: (30, 174, 190, 1)
shape mask: (30, 174, 190, 1)
912_007
(30, 174, 190, 1)
(190, 174, 30)
2
num_scan 3
(13, 30, 174, 190, 1)
Scan shape: (30, 174, 190, 1)
si:7 sf:20
BB: (13, 32, 32, 1)
shape scan: (30, 174, 190, 1)
shape mask: (30, 174, 190, 1)
912_010
(30, 174, 190, 1)
(190, 174, 30)
3
num_scan 4
(13, 30, 174, 190, 1)
Scan shape: (30, 174, 190, 1)
si:8 sf:19
BB: (11, 32, 32, 1)
shape scan: (30, 174, 190, 1)
shape mask: (30, 174, 190, 1)
912_011
(30, 174, 190, 1)
(190, 174, 30)
4
num_scan 5
(13,