In [7]:
import json
import numpy as np
import cv2 
import SimpleITK as sitk 
from collections import OrderedDict
from geotnf.transformation import GeometricTnf
import torch 
from image.normalization import NormalizeImageDict, normalize_image
import os 
from matplotlib import pyplot as plt
from parse_registration_json import ParserRegistrationJson
from parse_study_dict import ParserStudyDict

In [92]:
def _as_3d(array):
    """Return ``array`` unchanged if it is 3-D, or with a leading singleton slice axis if it is 2-D."""
    if array.ndim == 2:
        return array.reshape(1, array.shape[0], array.shape[1])
    return array


def preprocess_mri(fixed_img_mha, fixed_seg, pre_process_fixed_dest, coord, case):
    """Crop, rescale and export every MRI slice that contains segmentation.

    For each slice of the fixed MRI whose mask is non-empty, the masked MRI is
    cropped to a padded box around the mask (padded so that the crop is
    approximately square), scaled so its maximum is 255, and written as a JPEG
    together with the rescaled uncropped slice and the mask into
    ``<pre_process_fixed_dest>/<case>/``.

    Parameters
    ----------
    fixed_img_mha : str
        Path to the MRI volume (any format ``sitk.ReadImage`` accepts).
    fixed_seg : str
        Path to the MRI segmentation. It is resampled onto the MRI grid when the
        two arrays have different shapes.
    pre_process_fixed_dest : str
        Output root directory; a sub-directory named ``case`` is created in it
        (parents included).
    coord : dict
        Updated in place: ``coord[case]`` is (re)set to a dict of per-slice
        lists ``x_offset``, ``y_offset``, ``x``, ``y``, ``h``, ``w``, ``slice``.
        NOTE: ``x``/``w`` refer to rows and ``y``/``h`` to columns (see the
        ``cv2.boundingRect`` unpacking below).
    case : str
        Case identifier, used as the ``coord`` key and in output file names.

    Returns
    -------
    collections.OrderedDict
        An ordered copy of the updated ``coord``.
    """
    mri_image = sitk.ReadImage(fixed_img_mha)
    imMri = _as_3d(sitk.GetArrayFromImage(mri_image))

    imMriMask = sitk.ReadImage(fixed_seg)
    mask_shape = _as_3d(sitk.GetArrayFromImage(imMriMask)).shape

    # Resample the mask onto the MRI grid when the shapes differ (slice count
    # included, otherwise indexing the mask per MRI slice could go out of range).
    # Nearest-neighbour keeps label values intact; the filter's default
    # interpolator is linear, which would blur the segmentation.
    if mask_shape != imMri.shape:
        resampler = sitk.ResampleImageFilter()
        resampler.SetReferenceImage(mri_image)
        resampler.SetInterpolator(sitk.sitkNearestNeighbor)
        imMriMask = resampler.Execute(imMriMask)
        print("input mri and mri mask have different sizes")

    imMriMask = _as_3d(sitk.GetArrayFromImage(imMriMask))

    # Key order is kept stable because coord is later serialised to JSON.
    coord[case] = {key: [] for key in ('x_offset', 'y_offset', 'x', 'y', 'h', 'w', 'slice')}
    case_coord = coord[case]

    # makedirs (unlike a bare os.mkdir) also creates missing parent directories
    # and does not silently hide unrelated failures such as permission errors.
    case_dir = os.path.join(pre_process_fixed_dest, case)
    os.makedirs(case_dir, exist_ok=True)

    for slice_idx in range(imMri.shape[0]):
        # Copy so that scaling the mask below does not modify imMriMask.
        mri_mask = imMriMask[slice_idx, :, :].copy()
        if not np.any(mri_mask):
            continue

        # Binarise before masking: the crop is normalised by its own maximum
        # below, so a constant label value does not change the result, while a
        # raw multiply by e.g. 255 could overflow narrow integer MRI dtypes.
        mri = imMri[slice_idx, :, :] * (mri_mask != 0)

        if np.amax(mri_mask) == 1:
            mri_mask *= 255

        # np.argwhere yields (row, col); flip to (col, row) for OpenCV points.
        points = np.fliplr(np.argwhere(mri_mask != 0))
        # cv2.boundingRect returns (x, y, width, height) in OpenCV terms; the
        # unpacking deliberately swaps them, so from here on x/w index rows and
        # y/h index columns. Consumers of `coord` presumably depend on this
        # convention -- do not "fix" it without checking them.
        y, x, h, w = cv2.boundingRect(points)

        # Rescale the whole slice so its maximum maps to 255. A float divisor
        # avoids both truncation and the division by zero that int(max / 255)
        # produced for maxima below 255; all-zero slices are left untouched.
        slice_max = np.max(imMri[slice_idx, :, :])
        if slice_max > 0:
            imMri[slice_idx, :, :] = imMri[slice_idx, :, :] / (slice_max / 255)

        # Padding chosen so the padded crop is (approximately) square with side
        # h + 2 * y_offset.
        y_offset = int(h * (0.15 if h > w else 0.2))
        x_offset = int((h - w + 2 * y_offset) / 2)

        # Save x, y, x_offset, y_offset, h, w for each slice (pre-padding values).
        case_coord['x'].append(x)
        case_coord['y'].append(y)
        case_coord['h'].append(h)
        case_coord['w'].append(w)
        case_coord['slice'].append(slice_idx)
        case_coord['x_offset'].append(x_offset)
        case_coord['y_offset'].append(y_offset)

        min_x = max(x - x_offset, 0)
        min_y = max(y - y_offset, 0)
        crop = mri[min_x:x + w + x_offset, min_y:y + h + y_offset]

        h = h + 2 * y_offset
        w = w + 2 * x_offset

        # Scale the crop so its maximum is 255; guard against an all-zero crop
        # (the unguarded division produced NaNs).
        crop_max = np.max(crop)
        if crop_max > 0:
            crop = crop * 255.0 / crop_max

        # Resize to the padded box size (ups = 1, i.e. no upsampling).
        # cv2.resize takes dsize as (columns, rows), which with the swapped
        # naming above is (h, w).
        ups = 1
        upsHeight = int(h * ups)
        upsWidth = int(w * ups)
        upsMri = cv2.resize(crop.astype('float32'), (upsHeight, upsWidth), interpolation=cv2.INTER_CUBIC)

        file_suffix = case + '_' + str(slice_idx).zfill(2) + '.jpg'
        cv2.imwrite(os.path.join(case_dir, 'mri_' + file_suffix), upsMri)
        cv2.imwrite(os.path.join(case_dir, 'mriUncropped_' + file_suffix), imMri[slice_idx, :, :])
        cv2.imwrite(os.path.join(case_dir, 'mriMask_' + file_suffix), np.uint8(mri_mask))

    return OrderedDict(coord)

In [94]:

COORD_FILE = 'coord.txt'

# Reuse crop coordinates from a previous run so entries for studies that are
# not re-processed here survive. A missing file is the normal first-run case;
# an unreadable or invalid file is reported instead of being silently ignored.
try:
    with open(COORD_FILE) as f:
        coord = json.load(f)
except FileNotFoundError:
    coord = {}
except (OSError, ValueError) as err:
    print("could not read", COORD_FILE, "- starting with empty coordinates:", err)
    coord = {}

json_obj    = ParserRegistrationJson("jsonData/TCIA_FUSION.json")

cases       = json_obj.ToProcess.keys()
outputPath  = json_obj.output_path

# os.path.join avoids the doubled separator ("./results//preprocess/...") that
# plain concatenation produced; the trailing '' keeps the trailing slash that
# path-building code elsewhere may rely on.
preprocess_moving_dest = os.path.join(outputPath, 'preprocess', 'hist', '')
preprocess_fixed_dest = os.path.join(outputPath, 'preprocess', 'mri', '')

print(preprocess_fixed_dest)

# Preprocess the fixed (MRI) image of every study and record its crop coordinates.
for s in json_obj.studies:
    print("x"*30, "Processing", s, "x"*30)
    studyDict = json_obj.studies[s]

    studyParser = ParserStudyDict(studyDict)

    sid             = studyParser.id
    fixed_img_mha   = studyParser.fixed_filename
    fixed_seg       = studyParser.fixed_segmentation_filename
    # Not used in this cell; kept because later cells may rely on it.
    moving_dict     = studyParser.ReadMovingImage()

    coord = preprocess_mri(fixed_img_mha, fixed_seg, preprocess_fixed_dest, coord, sid)

with open(COORD_FILE, 'w') as json_file:
    json.dump(coord, json_file)

'do_affine'
'do_deformable'
'do_reconstruction'
'fast_execution'
'use_imaging_constraints'
Reading aaa0069 Study Json ./jsonData/reg_aaa0069.json
Reading HMU_010_FH Study Json ./jsonData/reg_HMU_010_FH.json
./results//preprocess/mri/
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx Processing aaa0069 xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx Processing HMU_010_FH xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
input mri and mri mask have different sizes
