In [None]:
import h5py
import nibabel as nib
import numpy as np
import os
import math
from PIL import Image
from skimage import measure
import matplotlib.pyplot as plt
import random
from tqdm import tqdm

#Only used for SimCLR dataset creation
#from DataAugmentations import preprocess_for_train
#import tensorflow as tf

In [None]:
"""
This script generates an arbitrary number of datasets from the directories given in volumeDirs and segmentationDirs.
The ith directory in each list must contain the volumes and/or segmentations for the ith dataset.
The number of .nii files in the volume and segmentation directories for each dataset must be equal, and the files must be in the same order (i.e. the first segmentation .nii file and the first volume .nii file must be from the same scan).
"""

volumeDirs = ["Datasets/RawData/TestingData/TestingVolumes/", "Datasets/RawData/TrainingData/TrainingVolumes/", "Datasets/RawData/ValidationData/ValidationVolumes/"]
segmentationDirs = ["Datasets/RawData/TestingData/TestingSegmentations/", "Datasets/RawData/TrainingData/TrainingSegmentations/", "Datasets/RawData/ValidationData/ValidationSegmentations/"]

fileNames = ["FullLiTSTestingDataset.hdf5", "FullLiTSTrainingDataset.hdf5", "FullLiTSValidationDataset.hdf5"]

#Number of CT scans in each volume directory, for inspection only (not used in the rest of the script).
#Hidden files (names starting with '.') are skipped, matching how every processing cell treats them.
numFiles = [
    sum(
        1
        for name in os.listdir(volumeDir)
        if not name.startswith('.') and os.path.isfile(os.path.join(volumeDir, name))
    )
    for volumeDir in volumeDirs
]

#Fraction of slices to keep from each scan, centred on the middle slice
keepRate = 0.3

#Resize all slices/segmentations to imageDim x imageDim
imageDim = 256

In [None]:
#Standard window-leveling performed on all slices
def window_level(vol, window_center, window_width):
    """Clamp intensities to the window [center - width // 2, center + width // 2].

    Args:
        vol: numpy array of intensities; it is copied, never modified in place.
        window_center: centre of the window.
        window_width: full width of the window; halved with floor division.

    Returns:
        A copy of ``vol`` where values below the window are set to its lower bound
        and values above it are set to its upper bound.
    """
    halfWidth = window_width // 2
    lowerBound = window_center - halfWidth
    upperBound = window_center + halfWidth

    clipped = vol.copy()
    clipped[clipped < lowerBound] = lowerBound
    clipped[clipped > upperBound] = upperBound
    return clipped

In [None]:
#Calculates the global minimum and maximum values for all datasets, used in normalization
#(slow: every volume is loaded; the following cell replaces the result with hard-coded values)
minVal = float('inf')
maxVal = float('-inf')

for volumeDir in volumeDirs:
    #Hidden files are not scans
    scanNames = [name for name in os.listdir(volumeDir) if name[0] != '.']
    for scanName in scanNames:
        volumeData = nib.load(volumeDir + scanName).get_fdata()
        minVal = min(np.amin(volumeData), minVal)
        maxVal = max(np.amax(volumeData), maxVal)

In [None]:
#Known global extrema of the full LiTS dataset, hard-coded because recomputing them
#requires loading every scan; these override any values computed above
minVal, maxVal = -3055, 5811

# **Standard Datasets**

In [None]:
#Run cell to create binary datasets based on the LiTS format
#NOTE(review): slices are window-leveled to [-160, 240] but normalized with the global raw min/max,
#so with the hard-coded LiTS values the stored intensities only span roughly [0.33, 0.37].
#Kept as-is for compatibility with existing datasets; confirm this is intended.
livers = []
total = []

for volumeDir, segmentationDir, datasetName in zip(volumeDirs, segmentationDirs, fileNames):
    #Pair scans by sorted file name and skip hidden files. os.listdir order is arbitrary, and indexing the
    #segmentation listing with the volume-listing index misaligned every pair after a hidden volume file.
    volumeNames = sorted(name for name in os.listdir(volumeDir) if not name.startswith('.'))
    segmentNames = sorted(name for name in os.listdir(segmentationDir) if not name.startswith('.'))
    if len(volumeNames) != len(segmentNames):
        raise ValueError(f"{volumeDir} has {len(volumeNames)} scans but {segmentationDir} has {len(segmentNames)}")

    numLivers = 0
    totalSlices = 0
    sliceNum = 0

    #Context manager flushes and closes every dataset file (previously only the last one was closed)
    with h5py.File(datasetName, "w") as datasetFile:
        for volumeName, segmentName in tqdm(zip(volumeNames, segmentNames), total=len(volumeNames)):
            #Loads segmentation and volume data from .nii files
            volumeData = window_level(nib.load(volumeDir + volumeName).get_fdata(), 40, 400)
            segmentData = nib.load(segmentationDir + segmentName).get_fdata()

            #Keeps the central slices: middle index +/- middle * keepRate
            middle = (volumeData.shape[2] - 1) / 2
            for plane in range(math.ceil(middle - middle * keepRate), math.floor(middle + middle * keepRate)):
                volumeSlice = np.array(Image.fromarray(volumeData[:,:,plane].astype(np.int16)).resize((imageDim, imageDim), Image.BILINEAR))
                #Normalize in float32 and only then downcast to float16 for storage. Doing the arithmetic in
                #float16 rounded the shifted intensities (~2900-3300 with the hard-coded min/max) to even numbers.
                volumeSlice = ((volumeSlice.astype(np.float32) - minVal) / (maxVal - minVal)).astype(np.float16)

                segmentSlice = segmentData[:,:,plane].astype(np.int16)
                #Max value of the full-resolution segmentation, capped at 1 (1 if any labelled voxel, 0 if not)
                label = min(np.amax(segmentSlice), 1)
                segmentSlice = np.array(Image.fromarray(segmentSlice).resize((imageDim, imageDim), Image.NEAREST))

                numLivers += label
                totalSlices += 1

                #Creates subgroup for the current slice, adds slice/segmentation/label data
                currSlice = datasetFile.create_group("Slice" + str(sliceNum))
                currSlice.create_dataset("Slice", data=volumeSlice)
                currSlice.create_dataset("Segmentation", data=segmentSlice)
                currSlice.attrs.create("ImageLabel", label, (1,), "int")

                sliceNum += 1

    livers.append(numLivers)
    total.append(totalSlices)

print(f"Liver Present: {livers} Total: {total}")

In [None]:
#Run cell to create multiclass datasets based on the format of the LiTS dataset
#Writes to the same fileNames as the binary-dataset cell, so run only one of the two
#NOTE(review): slices are window-leveled to [-160, 240] but normalized with the global raw min/max,
#so with the hard-coded LiTS values the stored intensities only span roughly [0.33, 0.37].
#Kept as-is for compatibility with existing datasets; confirm this is intended.
livers = []
total = []

for volumeDir, segmentationDir, datasetName in zip(volumeDirs, segmentationDirs, fileNames):
    #Pair scans by sorted file name and skip hidden files. os.listdir order is arbitrary, and indexing the
    #segmentation listing with the volume-listing index misaligned every pair after a hidden volume file.
    volumeNames = sorted(name for name in os.listdir(volumeDir) if not name.startswith('.'))
    segmentNames = sorted(name for name in os.listdir(segmentationDir) if not name.startswith('.'))
    if len(volumeNames) != len(segmentNames):
        raise ValueError(f"{volumeDir} has {len(volumeNames)} scans but {segmentationDir} has {len(segmentNames)}")

    numLivers = 0
    totalSlices = 0
    sliceNum = 0

    #Context manager flushes and closes every dataset file (previously only the last one was closed)
    with h5py.File(datasetName, "w") as datasetFile:
        for volumeName, segmentName in zip(volumeNames, segmentNames):
            #Loads segmentation and volume data from .nii files
            volumeData = window_level(nib.load(volumeDir + volumeName).get_fdata(), 40, 400)
            segmentData = nib.load(segmentationDir + segmentName).get_fdata()

            #Keeps the central slices: middle index +/- middle * keepRate
            middle = (volumeData.shape[2] - 1) / 2
            for plane in range(math.ceil(middle - middle * keepRate), math.floor(middle + middle * keepRate)):
                volumeSlice = np.array(Image.fromarray(volumeData[:,:,plane].astype(np.int16)).resize((imageDim, imageDim), Image.BILINEAR))
                #Normalize in float32 and only then downcast to float16 for storage. Doing the arithmetic in
                #float16 rounded the shifted intensities (~2900-3300 with the hard-coded min/max) to even numbers.
                volumeSlice = ((volumeSlice.astype(np.float32) - minVal) / (maxVal - minVal)).astype(np.float16)

                segmentSlice = segmentData[:,:,plane].astype(np.int16)
                #Max value of the full-resolution segmentation, capped at 1 (1 if any labelled voxel, 0 if not)
                label = min(np.amax(segmentSlice), 1)
                segmentSlice = np.array(Image.fromarray(segmentSlice).resize((imageDim, imageDim), Image.NEAREST))
                #One-hot masks for segmentation values 0 (background), 1 (liver) and 2 (tumor)
                backgroundSegment = (segmentSlice == 0).astype(int)
                liverSegment = (segmentSlice == 1).astype(int)
                tumorSegment = (segmentSlice == 2).astype(int)

                numLivers += label
                totalSlices += 1

                #Creates subgroup for the current slice, adds slice/segmentation/label data
                currSlice = datasetFile.create_group("Slice" + str(sliceNum))
                currSlice.create_dataset("Slice", data=volumeSlice)
                currSlice.create_dataset("BackgroundSegmentation", data=backgroundSegment)
                currSlice.create_dataset("LiverSegmentation", data=liverSegment)
                currSlice.create_dataset("TumorSegmentation", data=tumorSegment)
                currSlice.attrs.create("ImageLabel", label, (1,), "int")

                sliceNum += 1

    livers.append(numLivers)
    total.append(totalSlices)

print(f"Liver Present: {livers} Total: {total}")

In [None]:
#Prints out all slice information and segmentation maps for a binary dataset
#Used to make sure any obvious errors didn't happen
#Only valid for the binary layout: the multiclass cell writes no "Segmentation" dataset
for fileName in fileNames:
    #Context manager closes every file (previously only the last one was closed)
    with h5py.File(fileName, 'r') as dataFile:
        print(list(dataFile.keys()))

        for sliceKey in dataFile:
            sliceGroup = dataFile[sliceKey]
            print(sliceGroup["Slice"])
            print(sliceGroup["Segmentation"])
            print(sliceGroup.attrs.get("ImageLabel"))

# **Contrastive Datasets**

In [None]:
#Example selection for PolyCL-O
#Selects a random slice with specified label (targetLabel) from all slices in volDir/segmentDir
#Excludes all segmentation/volume files in the excludeFiles list
#Also excludes the current slice, determined by currVolumeName and currSliceNum
#Only tries to randomly select from each file maxRandomIter times, then excludes the file and tries again with a different file
#Performs all preprocessing (window-leveling, normalization) within this function

maxRandomIter = 10

def selectSlice(volDir, segmentDir, targetLabel, currVolumeName="", currSliceNum=-1, excludeFiles=None, maxResets=10):
    """Randomly pick a preprocessed slice whose capped label equals ``targetLabel``.

    Args:
        volDir: directory (with trailing separator) holding the volume .nii files.
        segmentDir: directory (with trailing separator) holding the segmentation .nii files.
        targetLabel: required value of min(max(segmentation slice), 1), i.e. 1 for a slice with any
            labelled voxel and 0 for an empty one.
        currVolumeName: volume file of the anchor slice; together with currSliceNum it is never returned.
        currSliceNum: slice index of the anchor slice.
        excludeFiles: volume/segmentation file names to skip initially. The list is never modified.
        maxResets: how many times the exclusions may be cleared after every scan was rejected
            before giving up.

    Returns:
        (volumeSlice, segmentSlice, volumeName): the window-leveled, resized and normalized volume slice,
        the un-resized int16 segmentation slice, and the volume file the slice came from.

    Raises:
        ValueError: if the two directories hold different numbers of non-hidden files.
        RuntimeError: if no acceptable slice is found (also the base class of the RecursionError the
            previous recursive version could raise).
    """
    #Pair volumes and segmentations by sorted file name (os.listdir order is arbitrary). Building filtered
    #lists also avoids the old remove-while-iterating loop, which skipped the entry after each removal.
    volumes = sorted(name for name in os.listdir(volDir) if not name.startswith("."))
    segmentations = sorted(name for name in os.listdir(segmentDir) if not name.startswith("."))
    if len(volumes) != len(segmentations):
        raise ValueError(f"{volDir} has {len(volumes)} files but {segmentDir} has {len(segmentations)}")
    scanPairs = list(zip(volumes, segmentations))

    #Local copy: the old mutable default list was appended to and so leaked exclusions across calls
    excluded = set(excludeFiles or ())
    resets = 0

    while True:
        candidates = [(vol, seg) for vol, seg in scanPairs if vol not in excluded and seg not in excluded]
        if not candidates:
            #Every scan was rejected: start over with nothing excluded, up to maxResets times
            if not excluded or resets >= maxResets:
                raise RuntimeError(f"No slice with label {targetLabel} found in {volDir}")
            excluded = set()
            resets += 1
            continue

        volumeName, segmentName = random.choice(candidates)
        segmentData = nib.load(segmentDir + segmentName).get_fdata()

        depth = segmentData.shape[2]
        lowSlice = int((depth / 2) - (depth / 2 * keepRate))
        highSlice = int((depth / 2) + (depth / 2 * keepRate))

        sliceInd = None
        for _ in range(maxRandomIter):
            candidateInd = random.randrange(lowSlice, highSlice)
            sliceLabel = min(np.amax(segmentData[:, :, candidateInd].astype(np.int16)), 1)
            isAnchor = candidateInd == currSliceNum and volumeName == currVolumeName
            if sliceLabel == targetLabel and not isAnchor:
                sliceInd = candidateInd
                break

        if sliceInd is None:
            #No suitable slice in this scan within maxRandomIter tries: exclude the pair and draw another scan
            excluded.update((volumeName, segmentName))
            continue

        volumeData = window_level(nib.load(volDir + volumeName).get_fdata(), 40, 400)
        volumeSlice = np.array(Image.fromarray(volumeData[:, :, sliceInd].astype(np.float64)).resize((imageDim, imageDim), Image.BILINEAR))
        volumeSlice -= float(minVal)
        volumeSlice /= float(maxVal - minVal)

        segmentSlice = segmentData[:, :, sliceInd].astype(np.int16)

        return volumeSlice, segmentSlice, volumeName

In [None]:
#Positive and negative example selection processes for PolyCL-S

#Selects random slice from current volume, excluding the current slice
def selectSliceRandPos(volDir, segmentDir, currVolumeName, currSegmentName, currSliceNum):
    """Pick a different random slice from the anchor's own scan as the PolyCL-S positive.

    Args:
        volDir: directory (with trailing separator) holding the volume .nii files.
        segmentDir: directory (with trailing separator) holding the segmentation .nii files.
        currVolumeName: volume file of the anchor scan.
        currSegmentName: segmentation file of the anchor scan.
        currSliceNum: anchor slice index; it is never returned.

    Returns:
        (volumeSlice, segmentSlice, currVolumeName): the window-leveled, resized and normalized volume
        slice, the un-resized int16 segmentation slice, and the anchor's volume file name.

    Raises:
        ValueError: if the central sampling range contains no slice other than currSliceNum
            (the previous rejection loop never terminated in that case).
    """
    segmentData = nib.load(segmentDir + currSegmentName).get_fdata()

    depth = segmentData.shape[2]
    lowSlice = int((depth / 2) - (depth / 2 * keepRate))
    highSlice = int((depth / 2) + (depth / 2 * keepRate))
    candidateSlices = [ind for ind in range(lowSlice, highSlice) if ind != currSliceNum]
    if not candidateSlices:
        raise ValueError(f"No slice other than {currSliceNum} in range [{lowSlice}, {highSlice}) of {currVolumeName}")
    sliceInd = random.choice(candidateSlices)

    volumeData = window_level(nib.load(volDir + currVolumeName).get_fdata(), 40, 400)
    volumeSlice = np.array(Image.fromarray(volumeData[:,:,sliceInd].astype(np.float64)).resize((imageDim, imageDim), Image.BILINEAR))
    volumeSlice -= float(minVal)
    volumeSlice /= float(maxVal - minVal)

    segmentSlice = segmentData[:,:,sliceInd].astype(np.int16)

    return volumeSlice, segmentSlice, currVolumeName

#First selects random CT scan that's not the current scan
#Then selects random slice from that scan, performs preprocessing and returns it
def selectSliceRandNeg(volDir, segmentDir, currVolumeName):
    """Pick a random slice from a different scan as the PolyCL-S negative.

    Args:
        volDir: directory (with trailing separator) holding the volume .nii files.
        segmentDir: directory (with trailing separator) holding the segmentation .nii files.
        currVolumeName: volume file of the anchor scan, which is never sampled.

    Returns:
        (volumeSlice, segmentSlice, negativeVolumeName): the window-leveled, resized and normalized
        volume slice, the un-resized int16 segmentation slice, and the volume file it came from.

    Raises:
        ValueError: if the directories hold different numbers of non-hidden files, if currVolumeName
            is not in volDir, or if there is no other scan to sample from.
    """
    #Pair by sorted file name and skip hidden files; os.listdir order is arbitrary
    volumes = sorted(name for name in os.listdir(volDir) if not name.startswith("."))
    segmentations = sorted(name for name in os.listdir(segmentDir) if not name.startswith("."))
    if len(volumes) != len(segmentations):
        raise ValueError(f"{volDir} has {len(volumes)} files but {segmentDir} has {len(segmentations)}")
    if currVolumeName not in volumes:
        raise ValueError(f"{currVolumeName} not found in {volDir}")

    otherScans = [ind for ind, name in enumerate(volumes) if name != currVolumeName]
    if not otherScans:
        raise ValueError(f"{volDir} contains no scan other than {currVolumeName}")
    volInd = random.choice(otherScans)

    volumeData = nib.load(volDir + volumes[volInd]).get_fdata()

    depth = volumeData.shape[2]
    sliceInd = random.randrange(int((depth / 2) - (depth / 2 * keepRate)), int((depth / 2) + (depth / 2 * keepRate)))

    volumeData = window_level(volumeData, 40, 400)
    volumeSlice = np.array(Image.fromarray(volumeData[:,:,sliceInd].astype(np.float64)).resize((imageDim, imageDim), Image.BILINEAR))
    volumeSlice -= float(minVal)
    volumeSlice /= float(maxVal - minVal)

    segmentData = nib.load(segmentDir + segmentations[volInd]).get_fdata()
    segmentSlice = segmentData[:,:,sliceInd].astype(np.int16)

    #Name of the scan the negative was drawn from (previously the anchor's own name was returned)
    return volumeSlice, segmentSlice, volumes[volInd]

In [None]:
#Creates positive example for SimCLR dataset, requires tensorflow to use external code
def simCLRPos(volDir, currVolumeName, currSliceNum):
    """Build an augmented SimCLR positive view of the anchor slice.

    Loads volDir + currVolumeName, window-levels it (center 40, width 400), resizes slice
    currSliceNum to imageDim x imageDim, adds a channel axis and passes it through the external
    ``preprocess_for_train``. Requires the tensorflow / DataAugmentations imports that are
    commented out in the first cell.

    Returns:
        The output of ``preprocess_for_train`` converted back to a numpy array.
    """
    windowedVolume = window_level(nib.load(volDir + currVolumeName).get_fdata(), 40, 400)
    resized = Image.fromarray(windowedVolume[:, :, currSliceNum].astype(np.float64)).resize((imageDim, imageDim), Image.BILINEAR)
    #Add a trailing channel axis: (H, W) -> (H, W, 1)
    channelSlice = np.array(resized)[:, :, np.newaxis]
    #NOTE(review): unlike the main slice this view is not shifted/scaled by minVal/maxVal, and the two
    #size arguments below are hard-coded to 256 rather than imageDim; confirm both are intended
    augmented = preprocess_for_train(tf.convert_to_tensor(channelSlice), 256, 256)
    return augmented.numpy()

In [None]:
#Run this cell to create a contrastive pre-training dataset for each of the PolyCL strategies or SimCLR

#0: PolyCL-O
#1: PolyCL-S
#2: SimCLR
datasetType = 0

#Any other value would leave positiveSlice undefined (or stale from an earlier run), so fail fast
if datasetType not in (0, 1, 2):
    raise ValueError(f"datasetType must be 0 (PolyCL-O), 1 (PolyCL-S) or 2 (SimCLR), got {datasetType}")

#Very similar in structure to the fully supervised dataset creation
for i, datasetName in enumerate(fileNames):
    sliceNum = 0

    #Pair scans by sorted file name and skip hidden files; os.listdir order is arbitrary
    volumes = sorted(name for name in os.listdir(volumeDirs[i]) if not name.startswith('.'))
    segmentations = sorted(name for name in os.listdir(segmentationDirs[i]) if not name.startswith('.'))
    if len(volumes) != len(segmentations):
        raise ValueError(f"{volumeDirs[i]} has {len(volumes)} scans but {segmentationDirs[i]} has {len(segmentations)}")

    #Context manager flushes and closes each HDF5 file, even if example selection raises part-way
    with h5py.File(datasetName, 'w') as file:
        for volumeName, segmentName in zip(volumes, segmentations):
            volumeData = window_level(nib.load(volumeDirs[i] + volumeName).get_fdata(), 40, 400)
            segmentData = nib.load(segmentationDirs[i] + segmentName).get_fdata()

            #Keeps the central slices: middle index +/- middle * keepRate
            middle = (volumeData.shape[2] - 1) / 2
            for plane in tqdm(range(math.ceil(middle - middle * keepRate), math.floor(middle + middle * keepRate))):
                sliceVolume = np.array(Image.fromarray(volumeData[:,:,plane].astype(np.float64)).resize((imageDim, imageDim), Image.BILINEAR))
                sliceVolume -= float(minVal)
                sliceVolume /= float(maxVal - minVal)

                sliceSegment = segmentData[:,:,plane].astype(np.int16)

                label = min(np.amax(sliceSegment), 1)

                #Uses different example selection strategies based on the type of dataset being created
                if datasetType == 0:
                    positiveSlice, positiveSegment, positiveScan = selectSlice(volumeDirs[i], segmentationDirs[i], label, currVolumeName=volumeName, currSliceNum=plane)
                    negativeSlice, negativeSegment, negativeScan = selectSlice(volumeDirs[i], segmentationDirs[i], 1 - label, currVolumeName=volumeName, currSliceNum=plane)
                elif datasetType == 1:
                    positiveSlice, positiveSegment, positiveScan = selectSliceRandPos(volumeDirs[i], segmentationDirs[i], volumeName, segmentName, plane)
                    negativeSlice, negativeSegment, negativeScan = selectSliceRandNeg(volumeDirs[i], segmentationDirs[i], volumeName)
                else:
                    positiveSlice = simCLRPos(volumeDirs[i], volumeName, plane)

                currGrp = file.create_group("Slice" + str(sliceNum))
                currGrp.create_dataset("MainSlice", data=sliceVolume)
                currGrp.create_dataset("PositiveSlice", data=positiveSlice)

                #Doesn't include any segmentation or negative example data for SimCLR datasets
                if datasetType != 2:
                    currGrp.create_dataset("MainSegment", data=sliceSegment)
                    currGrp.create_dataset("PositiveSegment", data=positiveSegment)
                    currGrp.create_dataset("NegativeSlice", data=negativeSlice)
                    currGrp.create_dataset("NegativeSegment", data=negativeSegment)
                    currGrp.attrs.create("ImageLabel", label, (1,), "int")
                    currGrp.attrs.create("PositiveScan", positiveScan)
                    currGrp.attrs.create("NegativeScan", negativeScan)

                currGrp.attrs.create("MainScan", volumeName)

                sliceNum += 1

            print("Finished scan: " + volumeName)

    print("Finished dataset: " + datasetName)

# **Visualization**

In [None]:
#Simple function for clamping segmentation maps between 0 and 1 for visualizing them as binary
def clamp(iter, min, max):
    """Clamp every element of a 2-D, row-indexable container into [min, max] in place.

    The parameter names shadow builtins but are kept for caller compatibility.

    Args:
        iter: sequence of mutable rows (e.g. a list of lists or of numpy rows); modified in place.
        min: lower bound.
        max: upper bound.

    Returns:
        The same ``iter`` object, after clamping.
    """
    lower, upper = min, max
    for row in iter:
        for col in range(len(row)):
            value = row[col]
            if value > upper:
                row[col] = upper
            elif value < lower:
                row[col] = lower

    return iter

In [None]:
#Opens the dataset file and prints relevant information about it: which slices contain the liver,
#the number of slices, and which hdf5 datasets can be referenced for each slice.
#The file is left open because the next cell reads from dataFile.
#NOTE(review): "MSDTestingDataset.hdf5" is not among the fileNames written by the cells above; confirm the intended file
sliceNum = 300
dataFile = h5py.File("MSDTestingDataset.hdf5", 'r')

#Keys of every slice whose ImageLabel attribute equals 1
liverSliceKeys = [key for key in dataFile.keys() if dataFile[key].attrs["ImageLabel"] == 1]

print(liverSliceKeys)
print(len(dataFile.keys()))
print(dataFile["Slice" + str(sliceNum)].keys())

In [None]:
#Visualizes the current slice with contours (level 0.9) of its segmentation clamped to [0, 1]
sliceGroup = dataFile["Slice" + str(sliceNum)]
main = sliceGroup["Slice"]
clamped = clamp(list(sliceGroup["Segmentation"]), 0, 1)
groundTruthContours = measure.find_contours(clamped, 0.9)

fig, ax = plt.subplots()
ax.imshow(main, cmap=plt.cm.gray)

#find_contours returns (row, col) points; plot expects (x=col, y=row)
for contour in groundTruthContours:
    rows, cols = contour[:, 0], contour[:, 1]
    ax.plot(cols, rows, linewidth=2, color='red')

ax.axis('image')
ax.set(xticks=[], yticks=[])
plt.show()