In [1]:
from tqdm import tqdm_notebook
import SimpleITK as sitk
import numpy as np
import re
from os.path import join
from os import listdir
from os import mkdir
from os.path import exists
from random import shuffle, sample
from scipy.misc import imsave

In [2]:
from matplotlib import pyplot as plt
# Define a function to plot a batch or list of image patches in a grid
def plot_image(images, images_per_row=8):
    """Plot a batch or list of 2-D image patches in a grid of subplots.

    Parameters
    ----------
    images : sequence of array-like
        Images shown one per subplot via ``Axes.imshow``.
    images_per_row : int, default 8
        Number of grid columns; the number of rows is
        ceil(len(images) / images_per_row).

    Unused grid cells are left blank. Nothing is drawn when `images` is empty.
    """
    n_images = len(images)
    if n_images == 0:
        return  # a grid with zero rows cannot be created
    # Integer ceiling division: correct under both true and floor division.
    n_rows = (n_images + images_per_row - 1) // images_per_row
    # squeeze=False keeps `axs` a 2-D array even for a single row or column,
    # so the iteration below works for every grid shape.
    _, axs = plt.subplots(n_rows, images_per_row, squeeze=False)
    for idx, ax in enumerate(axs.flat):
        if idx < n_images:
            ax.imshow(images[idx])
        ax.axis('off')
    plt.show()

In [33]:
# Root of the breast tomosynthesis data set; the helpers below read
# `<dataDir>/gtrs/*.gtrt` annotations and `<dataDir>/dataset/t<caseID>/*.dcm` slices.
# Set BREAST_TOMO_DATA_DIR to run outside the original cluster.
from os import environ
dataDir = environ.get("BREAST_TOMO_DATA_DIR",
                      "/projects/0/ismi2018/FINALPROJECTS/BREAST_THOMOSYNTHESIS")

def listCaseIDs(dataRoot):
    """Return the case IDs that have an annotation file in ``<dataRoot>/gtrs``.

    Every file name containing ".gtrt" contributes one ID: the name with
    ".gtrt" removed. Order follows ``os.listdir``.
    """
    caseIDs = []
    for fileName in listdir(join(dataRoot, "gtrs")):
        if ".gtrt" not in fileName:
            continue
        caseIDs.append(fileName.replace(".gtrt", ""))
    return caseIDs

def _parseCoordinates(section, prog):
    """Collect the 3-integer lines found after each ``[`` line in `section` as an (n, 3) int array."""
    lines = [line for block in prog.findall(section) for line in block.split("\n")[1:]]
    coords = [[int(c) for c in line.split()] for line in lines if len(line.split()) == 3]
    # reshape keeps the result 2-D even when no coordinates were found
    return np.asarray(coords, dtype=int).reshape(-1, 3)


def getLabels(caseID,dataRoot):
    """Parse the contour and point annotations of one case.

    Reads ``<dataRoot>/gtrs/<caseID>.gtrt``, splits it into blank-line
    separated sections, and extracts every line of exactly three integers
    that follows a ``[`` line.

    Parameters
    ----------
    caseID : str
        Case identifier (file stem of the ``.gtrt`` file).
    dataRoot : str
        Data-set root directory.

    Returns
    -------
    (annRegion, annPoints) : tuple of np.ndarray
        Integer arrays of shape (n, 3). They come from the section containing
        "contour" and the section containing "points" respectively; the last
        matching section wins. A missing or empty section yields a (0, 3)
        array instead of raising.
    """
    filename = join(dataRoot, "gtrs", caseID + ".gtrt")
    with open(filename, "r") as f:
        sections = f.read().split('\n\n')
    # "[" + newline, then a run from the character class {digits, space,
    # newline, '*', '+'}: i.e. the coordinate lines that follow a "[" line.
    prog = re.compile(r"\[\n[\d* \d* \d*\n+]+")
    annRegion = np.zeros((0, 3), dtype=int)
    annPoints = np.zeros((0, 3), dtype=int)
    for ann in sections:
        if "contour" in ann:
            annRegion = _parseCoordinates(ann, prog)
        elif "points" in ann:
            annPoints = _parseCoordinates(ann, prog)
    return annRegion, annPoints

def makeMask(caseID,dataRoot,dims=None):
    """Build a volume with 1.0 at every annotated point of a case and 0.0 elsewhere.

    Parameters
    ----------
    caseID : str
        Case identifier.
    dataRoot : str
        Data-set root directory.
    dims : tuple of int, optional
        Shape of the output volume. Defaults to (max coordinate + 1) along
        each of the three point columns.

    Returns
    -------
    np.ndarray
        Float array of shape `dims`.
    """
    # NOTE(review): uses the "points" annotation returned by getLabels; confirm
    # this (rather than the contour) is the intended mask source.
    _, points = getLabels(caseID, dataRoot)
    if dims is None:
        dims = np.max(points, axis=0) + 1
    Mask = np.zeros(dims)
    # Column k of `points` indexes axis k of the volume.
    Mask[points[:, 0], points[:, 1], points[:, 2]] = 1.0
    return Mask

def loadScan(caseID,dataRoot):
    """Stack every DICOM slice of a case into one 3-D volume.

    Slices in ``<dataRoot>/dataset/t<caseID>`` are ordered by the numeric
    value of their file stem. Slice i is written to ``volume[:, :, i]`` of an
    array shaped ``getDims(caseID, dataRoot)``.
    """
    sliceDir = join(dataRoot, "dataset", "t" + caseID)
    sliceFiles = sorted(
        (name for name in listdir(sliceDir) if ".dcm" in name),
        key=lambda name: float(name.replace(".dcm", "")),
    )

    volume = np.zeros(getDims(caseID, dataRoot))
    for sliceIndex, sliceFile in enumerate(sliceFiles):
        volume[:, :, sliceIndex] = sitk.GetArrayFromImage(sitk.ReadImage(join(sliceDir, sliceFile)))
    return volume
    
def getDims(caseID,dataRoot):
    """Return the volume shape for a case: (size[1], size[0], number of .dcm files).

    `size` is the SimpleITK ``GetSize()`` of the first listed slice. SimpleITK
    reports x first, so this is presumably (rows, cols, n_slices).
    """
    sliceDir = join(dataRoot, "dataset", "t" + caseID)
    dcmFiles = [name for name in listdir(sliceDir) if ".dcm" in name]
    size = sitk.ReadImage(join(sliceDir, dcmFiles[0])).GetSize()
    return (size[1], size[0], len(dcmFiles))

def getregionMax(regions):
    """Bounding box over the first two columns of a 2-D coordinate array.

    Returns
    -------
    tuple
        (x_min, x_max, y_min, y_max), with x taken from column 0 and y from column 1.
    """
    bounds = []
    for column in (regions[:, 0], regions[:, 1]):
        bounds.extend((np.min(column), np.max(column)))
    return tuple(bounds)

def createBinaryMask(pointCoordinates, regShape, z_value):
    """Rasterise annotation points into a 2-D mask of shape `regShape`.

    Parameters
    ----------
    pointCoordinates : array-like of int, shape (n, 3)
        Points whose first two columns index the mask and whose third column
        is a slice number.
    regShape : tuple of int
        Shape of the (2-D) mask.
    z_value : int
        The slice being rendered.

    Returns
    -------
    np.ndarray
        Float mask: 1 where a point on slice `z_value` lies, -1 where only
        points from other slices lie, 0 elsewhere. A current-slice point always
        wins over an off-slice point at the same pixel. Points outside the mask
        are ignored, so negative indices do not wrap around.
    """
    binaryMask = np.zeros(regShape)
    pts = np.asarray(pointCoordinates, dtype=int)
    if pts.size == 0:
        return binaryMask
    rows, cols, zs = pts[:, 0], pts[:, 1], pts[:, 2]
    inside = ((rows >= 0) & (rows < binaryMask.shape[0]) &
              (cols >= 0) & (cols < binaryMask.shape[1]))
    onSlice = inside & (zs == z_value)
    offSlice = inside & (zs != z_value)
    # Write -1 first so current-slice points (1) take precedence regardless of order.
    binaryMask[rows[offSlice], cols[offSlice]] = -1
    binaryMask[rows[onSlice], cols[onSlice]] = 1
    return binaryMask

def saveRegions(caseID, dataRoot, regions, points, boundingSize=50, savedir='.'):
    """Crop each annotated slice of a case and save image and mask PNGs.

    The crop box is the bounding box of all points' first two columns, padded
    by `boundingSize` pixels and clipped to the image. For every slice number z
    in ``points[:, 2]`` that has a ``<z>.dcm`` file in
    ``<dataRoot>/dataset/t<caseID>``, the function writes two files:

    * the crop of that slice to ``<savedir>/regions/images/<caseID>_<z>.png``;
    * a mask of the same shape, from createBinaryMask (1 = point on this
      slice, -1 = point on another slice), to
      ``<savedir>/regions/masks/<caseID>_<z>.png``.

    Parameters
    ----------
    caseID : str
        Case identifier.
    dataRoot : str
        Data-set root directory.
    regions : np.ndarray
        Contour annotation. Currently unused; the box comes from `points`.
    points : np.ndarray, shape (n, 3)
        Columns 0/1 index the first two axes of the transposed slice array;
        column 2 is the slice number.
    boundingSize : int, default 50
        Padding, in pixels, around the points' bounding box.
    savedir : str, default '.'
        Output root directory.
    """
    from os import makedirs

    if len(points) == 0:
        return  # no annotation -> no bounding box and nothing to save

    caseDir = join(dataRoot, "dataset", ''.join(['t', caseID]))
    sliceNames = {f.replace(".dcm", "") for f in listdir(caseDir) if ".dcm" in f}

    imageDir = join(savedir, "regions", "images")
    maskDir = join(savedir, "regions", "masks")
    for outDir in (imageDir, maskDir):
        if not exists(outDir):
            makedirs(outDir)  # also creates `regions/` when it is missing

    x_min, x_max, y_min, y_max = getregionMax(points)
    # Clipped crop origin; the mask is shifted by exactly this offset so it stays
    # aligned with the crop even when the padded box hits the image border.
    x0 = max(0, x_min - boundingSize)
    y0 = max(0, y_min - boundingSize)
    localPoints = points - [x0, y0, 0]

    # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; this needs an
    # older SciPy or a switch to e.g. imageio.imwrite.
    for z in sorted(set(points[:, 2])):
        if str(z) not in sliceNames:
            continue
        image = sitk.GetArrayFromImage(sitk.ReadImage(join(caseDir, ''.join([str(z), '.dcm']))))
        image = image.transpose([2, 1, 0])
        x1 = min(x_max + boundingSize, image.shape[0])
        y1 = min(y_max + boundingSize, image.shape[1])
        region = image[x0:x1, y0:y1, 0]
        imsave(join(imageDir, ''.join([caseID, "_", str(z), ".png"])), region)

        mask = createBinaryMask(localPoints, region.shape, z)
        imsave(join(maskDir, ''.join([caseID, "_", str(z), ".png"])), mask)
            
            

In [34]:
# Export image crops and masks for every annotated case.
ids = listCaseIDs(dataDir)
for caseID in ids:
    regions, points = getLabels(caseID, dataDir)
    saveRegions(caseID, dataDir, regions, points)