In [1]:
from tqdm import tqdm_notebook
import SimpleITK as sitk
import numpy as np
import re
from os.path import join
from os import listdir
from random import shuffle, sample

In [2]:
from matplotlib import pyplot as plt
# Define a function to plot a batch or list of image patches in a grid
def plot_image(images, images_per_row=8):
    
    fig, axs = plt.subplots(int(np.ceil(len(images)/images_per_row)), images_per_row)
    
    c = 0
    for ax_row in axs:
        for ax in ax_row:
            if c < len(images):
                ax.imshow(images[c])
            ax.axis('off')            
            c += 1
    plt.show()

In [3]:
dataDir = "/projects/0/ismi2018/FINALPROJECTS/BREAST_THOMOSYNTHESIS"

def listCaseIDs(dataRoot):
    cases = [f.replace(".gtrt","") for f in listdir(join(dataRoot,"gtrs")) if ".gtrt" in f]
    return cases

def getPoints(filename,dataRoot):
    annotation = open(filename,"r").read()
    regex = "\[\n[\d* \d* \d*\n+]+"
    prog = re.compile(regex)
    result = prog.findall(annotation)
    
    result = [item for r in result for item in r.split("\n")[1:]]
    result = [[int(c) for c in cords.split()] for cords in result if len(cords.split()) == 3]
    result = np.asarray(result)
    return result

def makeMask(caseID,dataRoot,dims=None,border=(0,0,0)):
    border = np.array(border)
    anotationFileName = join(dataRoot,"gtrs",caseID+".gtrt")
    points = getPoints(anotationFileName,dataRoot)
    if(dims is None):
        dims = np.max(points,axis=0) + 1
    else:
        dims = np.array(dims)
    Mask = np.zeros(dims + border)
    Mask[points[:,1]+border[0],points[:,2]+border[1],points[:,2]+border[2]] = 1.0
    return Mask

def loadScan(caseID,dataRoot,border=(0,0,0)):
    dataFolder = join(dataRoot,"dataset","t" + caseID)
    images = [f for f in listdir(dataFolder) if ".dcm" in f]
    images.sort(key=lambda x: float(x.replace(".dcm","")))
    
    imageDim = getDims(caseID,dataRoot)
    border = np.array(border)
    
    scan = np.zeros(getDims(caseID,dataRoot) + border*2)
    for i,f in enumerate(images):
        image = sitk.ReadImage ( join(dataFolder,f) )
        scan[border[0]:border[0]+imageDim[0],border[1]:border[1]+imageDim[1],i+border[2]] = sitk.GetArrayFromImage(image)
    return scan
    
def getDims(caseID,dataRoot):
    dataFolder = join(dataRoot,"dataset","t" + caseID)
    images = [f for f in listdir(dataFolder) if ".dcm" in f]
    image = sitk.ReadImage ( join(dataFolder,images[0]) )
    dims = image.GetSize()
    dims = (dims[1],dims[0],len(images))
    return dims

In [None]:
caseID = "0200710601cl"
scan = loadScan(caseID,dataDir,border=(100,100,100))
mask = makeMask(caseID,dataDir,dims=scan.shape)
print(scan.shape)
print(mask.shape)

In [4]:
def getPatch(location,scan,patchSize=(101,101,1),reduceDim=True):
    #did you give the cordinates and the patch size in atleast three dimentions
    assert(len(location)>2 and len(patchSize)>2 )
    
    halfPatch = [int((p-1)/2) for p in patchSize]
    xmin,xmax = location[0]-halfPatch[0],location[0]+halfPatch[0]+1
    ymin,ymax = location[1]-halfPatch[1],location[1]+halfPatch[1]+1
    zmin,zmax = location[2]-halfPatch[2],location[2]+halfPatch[2]+1

    #is the entire patch within the bounds of the scan
    assert(xmin>=0 and xmax < scan.shape[0] and xmin <= xmax)
    assert(ymin>=0 and ymax < scan.shape[1] and ymin <= ymax)
    assert(zmin>=0 and zmax < scan.shape[2] and zmin <= zmax)
    
    patch = scan[xmin:xmax,ymin:ymax,zmin:zmax]
    
    #if a dimention is only 1 in size and reduce dimentions is true remove the dimention
    if(reduceDim and patchSize[0]==1):
        patch = patch[0,:,:]
    if(reduceDim and patchSize[1]==1):
        patch = patch[:,0,:]
    if(reduceDim and patchSize[2]==1):
        patch = patch[:,:,0]
    
    return patch

In [17]:
class PatchGenerator(object):
    
    def __init__(self,dataDir,batch_size,n_batches, patch_size,augmentation_fn=None):
        self.dataDir=dataDir
        self.batch_size = batch_size
        self.patch_size = patch_size
        #ignore an part of the border so you dont get half patches
        self.border = [int((p-1)/2) for p in self.patch_size]
        
        self.samplesPerClass = int(self.batch_size/2)
        
        self.n_batches = n_batches
        
        self.augmentation_fn = augmentation_fn
        
        self.cases = listCaseIDs(self.dataDir)
        shuffle(self.cases)
        self.scan = None
        self.mask = None
        self.TP = []
        self.TN = []
    
    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        # Provide length in number of batches
        return self.n_batches
    
    def loadScan(self,caseId):
        self.scan = loadScan(caseId,self.dataDir,border=self.border)
        self.mask = makeMask(caseId,dataDir,dims=self.scan.shape,border=self.border)
        print("loaded")
        
        #get the locations of the calsifications
        self.TP = np.where(self.mask==1.0)
        
        #get the negative sample locations
        self.TN = np.where(self.mask==0.0)
        
        #shuffle the case ids so the next on will be different
        shuffle(self.cases)
    
    def next(self):
        X = []
        Y = []
        
        #if there are not enough samples in the current file left (or no file is loaded) load a new scan
        if(len(self.TP)<self.samplesPerClass):
            #load the first case id
            print("loading scan")
            self.loadScan(self.cases[0])
            print("scan loaded")
        
        #select a number of positive and negative samples
        print("select samples")
        Pidx = list(range(len(self.TP[0])))
        shuffle(Pidx)
        P_samples = zip(self.TP[0][Pidx[:self.samplesPerClass]],self.TP[1][Pidx[:self.samplesPerClass]],self.TP[2][Pidx[:self.samplesPerClass]])
        
        Nidx = list(range(len(self.TN[0])))
        shuffle(Nidx)
        N_samples = zip(self.TN[0][Nidx[:self.samplesPerClass]],self.TN[1][Nidx[:self.samplesPerClass]],self.TN[2][Nidx[:self.samplesPerClass]])
        
        print("samples selected")
        
        #make the positive patches
        print("make positive patches")
        for loc in P_samples:
            patch = getPatch(loc,self.scan,patchSize=self.patch_size)
            X.append(patch)
            Y.append((1.0,0.0))
        print("positive patches made")
            
        #make the positive patches
        print("make negative patches")
        for loc in N_samples:
            patch = getPatch(loc,self.scan,patchSize=self.patch_size)
            X.append(patch)
            Y.append((0.0,1.0))
        print("negative patches made")
        
        print("stack patches")
        batch_x = np.stack(X).astype('float32')
        batch_y = np.stack(Y).astype('float32')
        print("patches stacked")
        return batch_x,batch_y
            

In [None]:
gen = PatchGenerator(dataDir,32,1,(21,21,1))

for batch_x, batch_y in gen:
    print(batch_x.shape)
    print(batch_y.shape)
    plot_image(batch_x, images_per_row=8)
    break

loading scan
loaded
scan loaded
select samples


In [None]:
TP = np.where(mask==1.0)
TP = [(x,y,z) for x,y,z in zip(TP[0],TP[1],TP[2])]

patches = [getPatch(loc,scan,patchSize=(11,11,1)) for loc in TP]
patchMasks = [getPatch(loc,mask,patchSize=(11,11,1)) for loc in TP]

print(patch.shape)
print(patchMask.shape)

In [None]:
from matplotlib import pyplot as plt
plt.close("all")

for patch,mask in zip(patches,patchMasks):
    plt.figure()
    plt.subplot(121)
    plt.imshow(patch)
    plt.subplot(122)
    plt.imshow(mask)
    plt.show()