In [None]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

from extraction import extract_patches
from reconstruction import perform_voting, generate_indexes

from model import Multimodel

In [None]:
from keras import backend as K
import tensorflow as tf
# Let TensorFlow allocate GPU memory on demand (allow_growth) rather than reserving
# it all at start-up, and register a session with that configuration as the Keras
# backend session.
config_tf = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
K.set_session(tf.Session(config=config_tf))

In [None]:
def computing_usefull_paches(img_filename, curr_patch_shape, step, threshold):
    """Flag the patches of an image volume that contain enough non-zero voxels.

    The volume is loaded with nibabel and its third axis is moved to the front.
    ``extract_patches`` is then run on the binary "non-zero" mask with patch size
    ``(1,) + curr_patch_shape`` and stride ``(1,) + step``, so every patch lies
    in a single slice along the original third axis.

    Parameters
    ----------
    img_filename : str
        Path of the image to read (any format ``nib.load`` accepts).
    curr_patch_shape : tuple of int
        In-slice patch shape, e.g. ``(64, 64)``.
    step : tuple of int
        In-slice extraction stride, e.g. ``(32, 32)``.
    threshold : int
        A patch is flagged when its count of non-zero voxels is strictly
        greater than this value.

    Returns
    -------
    numpy.ndarray of bool
        One flag per extracted patch. It is meant to index the first axis of
        ``extract_patches`` output computed with the same shape and step.
    """
    img = nib.load(img_filename)
    # get_data() is deprecated (removed in nibabel 5.0); np.asanyarray(img.dataobj)
    # yields the same array with the on-disk dtype preserved.
    volume = np.rollaxis(np.asanyarray(img.dataobj), 2, 0)
    mask_patches = extract_patches(volume != 0, (1, ) + curr_patch_shape, (1, ) + step)
    # Count non-zero voxels over every axis except the leading patch-index axis
    # (identical to axis=(1, 2, 3) for 4-D patch arrays).
    nonzero_counts = np.sum(mask_patches, axis=tuple(range(1, mask_patches.ndim)))
    return nonzero_counts > threshold

In [None]:
import os
import sys
import nibabel as nib
%matplotlib inline


# Dataset location: one sub-directory per training case.
dataset_path = '/home/mostafasalem/AllData/New_Generator_VH_ISBI/training'

# Mask files W1 .. W{num_masks} of each case; every mask becomes an extra input channel.
mask_pattern = dataset_path+'/{0}/W{1}.nii.gz'
num_masks=8

# Basal images are the network inputs, follow-up images are the network targets.
basalImg_pattern  = dataset_path+'/{0}/{1}_normalized_filled_WMHIM_smoothed.nii.gz'
followupImg_pattern= dataset_path+'/{0}/{1}_normalized.nii.gz'

# Modalities, in-slice extraction stride and patch shape.
modalities=['t1', 'flair']

step = (32, 32)
curr_patch_shape = (64, 64)

# Training cases (lesion and healthy): only directories count, so stray files
# in the dataset folder are ignored instead of crashing the loader.
training_patients = sorted(
    f for f in os.listdir(dataset_path)
    if os.path.isdir(os.path.join(dataset_path, f)))

# A patch is kept when it has more than `threshold` non-zero voxels
# (0.0 * patch size => any patch with at least one non-zero voxel).
threshold = np.int32(0.0 * np.prod(curr_patch_shape[:]))

# Per-case patch arrays are collected in lists and concatenated once after the
# loop. Stacking onto the accumulated array for every case would re-copy all
# previously read data each time (quadratic time and memory).
in_chunks = {m: [] for m in modalities}
out_chunks = {m: [] for m in modalities}

for p in training_patients:
    print('Reading case No. {} :'.format(p))

    # Patch selection is driven by the basal T1 image.
    print('\tComputing the usefull paches')
    useful_patches = computing_usefull_paches(basalImg_pattern.format(p,'t1'), curr_patch_shape, step, threshold)
    N = np.sum(useful_patches)

    # Mask patches are appended as extra channels to every modality's input.
    allmasks_patches = np.empty((N, num_masks, ) + curr_patch_shape)
    print('\tExtracting the mask patches')
    for m in range(1, num_masks + 1):
        mask = np.asanyarray(nib.load(mask_pattern.format(p, m)).dataobj)
        mask = np.rollaxis(mask, 2, 0)
        mask_patches = extract_patches(mask, (1, ) + curr_patch_shape, (1, ) + step)
        allmasks_patches[:, m - 1] = mask_patches[useful_patches].reshape((-1, ) + curr_patch_shape)
        del mask_patches

    # Inputs: 1 basal-image channel followed by num_masks mask channels.
    print('\tExtracting the modalities patches')
    for m in modalities:
        print('\t{0} basal patches'.format(m))
        basalImg = np.asanyarray(nib.load(basalImg_pattern.format(p, m)).dataobj)
        basalImg = np.rollaxis(basalImg, 2, 0)
        modality_patches = extract_patches(basalImg, (1, ) + curr_patch_shape, (1, ) + step)
        modality_patches = modality_patches[useful_patches].reshape((-1, 1, ) + curr_patch_shape)
        in_chunks[m].append(np.hstack((modality_patches, allmasks_patches)))
        del modality_patches
    del allmasks_patches

    # Targets: the follow-up image patches at the same locations.
    for m in modalities:
        print('\t{0} followup patches'.format(m))
        followupImg = np.asanyarray(nib.load(followupImg_pattern.format(p, m)).dataobj)
        followupImg = np.rollaxis(followupImg, 2, 0)
        modality_patches = extract_patches(followupImg, (1, ) + curr_patch_shape, (1, ) + step)
        out_chunks[m].append(modality_patches[useful_patches].reshape((-1, 1, ) + curr_patch_shape))
        del modality_patches

    print('')
print('')

# The original code prepended each new case, so the newest case comes first.
# That order is kept, because Keras' validation_split holds out the *last*
# samples. The empty float64 seed array keeps the result float64 (as before)
# and well-defined when there are no cases.
in_train = {}
out_train = {}
for m in modalities:
    in_train[m] = np.concatenate(
        in_chunks[m][::-1] + [np.empty((0, 1 + num_masks, ) + curr_patch_shape)], axis=0)
    out_train[m] = np.concatenate(
        out_chunks[m][::-1] + [np.empty((0, 1, ) + curr_patch_shape)], axis=0)
del in_chunks, out_chunks


In [None]:
print "Inputs"
for key in in_train:
    print "Size of input modality {} :".format(key)
    print in_train[key].shape
    
print
print "Outputs"
for key in out_train:
    print "Size of output modality {} :".format(key)
    print out_train[key].shape


In [None]:
# Multimodel configuration: one encoder per input modality. Each receives
# 9-channel patches (1 image channel + 8 mask channels, as stacked in in_train).
input_modalities = ['T1', 'FLAIR']
output_modalities = ['T1_Gen', 'FLAIR_Gen']
output_weights = dict(T1_Gen=1.0, FLAIR_Gen=1.0, concat=1.0)
latent_dim = 32
channels = [9] * len(input_modalities)
use_dropout = [False] * len(input_modalities)
curr_patch_shape = (64, 64)
patch_shape = curr_patch_shape
scale = 1

lesionsGen_model = Multimodel(input_modalities,
                              output_modalities,
                              output_weights,
                              latent_dim,
                              channels,
                              patch_shape,
                              use_dropout,
                              scale)
lesionsGen_model.build()

In [None]:
from keras.callbacks import EarlyStopping, ModelCheckpoint

import os

# Keep only the weights of the epoch with the lowest validation loss
# (ModelCheckpoint monitors 'val_loss' by default). An EarlyStopping callback
# (patience=10) was tried previously and is currently disabled.
#
# ModelCheckpoint does not create missing directories. Create 'models/' up front,
# otherwise writing the checkpoint fails at the end of the first epoch.
if not os.path.isdir('models'):
    os.makedirs('models')
checkpointer = ModelCheckpoint('models/model_WMHIGen_UsingVHISBISmoothed.h5', save_best_only=True, save_weights_only=True)

N = len(in_train[modalities[0]])

# One target per model output, in output order:
# - modalities[0] follow-up patches for outputs 0-2,
# - modalities[1] follow-up patches for outputs 3-5,
# - then two zero-width placeholders of shape (N, 2, 0) and (N, 1, 0).
# The placeholders presumably feed outputs whose losses ignore y_true
# (TODO confirm against Multimodel).
targets = ([out_train[modalities[0]]] * 3
           + [out_train[modalities[1]]] * 3
           + [np.empty((N, 2, 0)), np.empty((N, 1, 0))])

# validation_split takes the last 30% of the samples (no shuffling before the split).
lesionsGen_model.model.fit(
    [in_train[modalities[0]], in_train[modalities[1]]],
    targets,
    validation_split=0.3, epochs=70,
    verbose=2,
    callbacks=[checkpointer])

In [None]:
# Reload the weights written by the ModelCheckpoint callback of the training cell
# (save_best_only=True, i.e. the best epoch by the monitored validation metric).
lesionsGen_model.model.load_weights('models/model_WMHIGen_UsingVHISBISmoothed.h5')

## Testing on patient images (using the 8 WMHI masks) (patients or healthy subjects)

In [None]:
import os
import sys
import nibabel as nib
%matplotlib inline

# Dataset location: one sub-directory per test case.
dataset_path = '/home/mostafasalem/AllData/New_Generator_VH/VSI_test_linear'
# Mask files W1 .. W{num_masks} of each case.
mask_pattern = dataset_path+'/{0}/W{1}.nii.gz'
num_masks=8

# Input image patterns.
brainmask_pattern = dataset_path+'/{0}/brainmask.nii.gz'
basalImg_pattern  = dataset_path+'/{0}/{1}_normalized_filled_WMHIM_smoothed.nii.gz'
# Alternative (non-normalized) input: dataset_path+'/{0}/{1}_filled_WMHIM_smoothed.nii.gz'

# Output location of the generated volumes.
generatedDir_pattern = dataset_path+'/{0}/generated_UsingVH'
generatedImg_pattern = dataset_path+'/{0}/generated_UsingVH/{1}_gen.nii.gz'

# Modalities, in-slice extraction stride (denser than in training) and patch shape.
modalities=['t1', 'flair']
step = (16, 16)
curr_patch_shape = (64, 64)

# Test cases (lesion and healthy): only directories count; stray files are ignored.
testing_patients = sorted(
    f for f in os.listdir(dataset_path)
    if os.path.isdir(os.path.join(dataset_path, f)))

# Per-modality network inputs of the case currently being processed.
in_test={}

for p in testing_patients:
    print('Reading case No. {} :'.format(p))

    # Outputs are written in modality order, so the last modality's file
    # existing means the case was fully processed.
    if os.path.isfile(generatedImg_pattern.format(p, modalities[-1])):
        print('\tCase no. {} is already computed...'.format(p))
        continue

    generatedDir = generatedDir_pattern.format(p)
    if not os.path.exists(generatedDir):
        os.makedirs(generatedDir)

    # Shape / affine of each output volume, taken from the matching basal image.
    out_vol_shape={}
    out_vol_affine={}

    # Brain mask: used to count the patches and to zero voxels outside the brain.
    brainmask = np.asanyarray(nib.load(brainmask_pattern.format(p)).dataobj)
    brainmask = np.rollaxis(brainmask, 2, 0)
    brainmask_patches = extract_patches(brainmask, (1, ) + curr_patch_shape, (1, ) + step)
    N = brainmask_patches.reshape((-1, 1, ) + curr_patch_shape).shape[0]
    del brainmask_patches

    # Mask patches are appended as extra channels to every modality's input
    # (all patches are used here, no usefulness filtering).
    allmasks_patches = np.empty((N, num_masks, ) + curr_patch_shape)
    print('\tExtracting the mask patches')
    for m in range(1, num_masks + 1):
        mask = np.asanyarray(nib.load(mask_pattern.format(p, m)).dataobj)
        mask = np.rollaxis(mask, 2, 0)
        mask_patches = extract_patches(mask, (1, ) + curr_patch_shape, (1, ) + step)
        allmasks_patches[:, m - 1] = mask_patches.reshape((-1, ) + curr_patch_shape)
        del mask_patches

    # Inputs: 1 basal-image channel followed by num_masks mask channels.
    print('\tExtracting the modalities patches')
    for m in modalities:
        print('\t{0} basal patches'.format(m))
        basalImg_data = nib.load(basalImg_pattern.format(p, m))
        out_vol_affine[m] = basalImg_data.affine
        basalImg = np.rollaxis(np.asanyarray(basalImg_data.dataobj), 2, 0)
        out_vol_shape[m] = basalImg.shape
        modality_patches = extract_patches(basalImg, (1, ) + curr_patch_shape, (1, ) + step)
        modality_patches = modality_patches.reshape((-1, 1, ) + curr_patch_shape)
        in_test[m] = np.hstack((modality_patches, allmasks_patches))
        del modality_patches
    print('')
    del allmasks_patches

    print('\tThe prediction using the lesions mask ')
    preds = lesionsGen_model.model.predict([in_test[modalities[0]], in_test[modalities[1]]], verbose=2)
    for k, m in enumerate(modalities):
        # Outputs 3k .. 3k+2 are trained against modality k (see the fit cell).
        # Index 3k+2 is the one exported; its exact role is defined by Multimodel.
        # TODO confirm.
        print('\tSaving modality {0} at index {1}'.format(m, k*3 + 2))
        volume = perform_voting(preds[k*3 + 2].reshape((-1, 1, ) + curr_patch_shape), (1, ) + curr_patch_shape, out_vol_shape[m], (1, ) + step)
        volume[brainmask==0]=0
        # Undo the initial rollaxis so the volume matches the input orientation.
        volume = np.rollaxis(volume, 0, 3)

        nib.save(nib.Nifti1Image(volume, out_vol_affine[m]), generatedImg_pattern.format(p, m))
    del preds