In [1]:
import os
import cc3d
import torch
from monai.networks.nets import SwinUNETR
from torch.utils.data import DataLoader, random_split
import subprocess
import random
import scipy.ndimage as ndimage
from scipy.io import loadmat
import numpy as np
from monai.losses import DiceLoss
import nibabel as nib
import napari
from skimage import morphology
from cc_torch import connected_components_labeling

In [2]:
coarseModelPath = "/home/gabriela/Documents/gabiSegmentationProject/newCOPD_coarse_model/weights/19_v2_half.pth"
LRModelPath = "/home/gabriela/Documents/gabiSegmentationProject/newCOPD_LR_model/weights/14_v3_half.pth"
rightModelPath = "/home/gabriela/Documents/gabiSegmentationProject/newCOPD_right_model/weights/41_v2_half.pth"
leftModelPath = "/home/gabriela/Documents/gabiSegmentationProject/newCOPD_left_model/weights/17_v3_half.pth"

In [3]:
class oneFolderFinalDataset(torch.utils.data.Dataset):
    def __init__(self, imageDirectory):
        #self.imageList = sorted([file for file in os.listdir(imageDirectory) if "image" in file])
        self.imageList = sorted([file for file in os.listdir(imageDirectory)])
        #self.labelList = sorted([file for file in os.listdir(imageDirectory) if "mask" in file])
        self.imageDirectory = imageDirectory
    
    def __len__(self):
        x = len(self.imageList)
        return x

    def __getitem__(self, idx):
        imagePath = os.path.join(self.imageDirectory, self.imageList[idx])
        #labelPath = os.path.join(self.imageDirectory, self.labelList[idx])
        #image = nib.load(imagePath).get_fdata()
        matfile = loadmat(imagePath)
        inhale = np.array(matfile['T00'], dtype=np.int32)
        exhale = np.array(matfile['T50'], dtype=np.int32)
        print(np.min(inhale),np.max(inhale))
        inhale = 1 + ((inhale - 1024) / 1000)
        inhale = np.clip(inhale, 0, 1) * 255
        #label = nib.load(labelPath).get_fdata()
        inhale = torch.from_numpy(inhale).float().unsqueeze(0)
        exhale = 1 + ((exhale-1024) / 1000)
        exhale = np.clip(exhale, 0, 1) * 255
        #label = nib.load(labelPath).get_fdata()
        exhale = torch.from_numpy(exhale).float().unsqueeze(0)
        #label = torch.from_numpy(label).float().unsqueeze(0)

        return inhale, exhale, self.imageList[idx]
    
    def findIndex(self,filename):
        for i in range(len(self.imageList)):
            if filename == self.imageList[i]:
                return i
        return -1
    
def display_nvidia_smi_memory_usage():
    try:
        result = subprocess.run(['nvidia-smi', '--query-gpu=memory.total,memory.used,memory.free', '--format=csv,nounits,noheader'], check=True, text=True, stdout=subprocess.PIPE)
        output = result.stdout.strip().split('\n')
        for i, line in enumerate(output):
            total, used, free = map(int, line.split(','))
            print(f"GPU {i}: Total Memory: {total} MiB, Used Memory: {used} MiB, Free Memory: {free} MiB")
    except subprocess.CalledProcessError:
        print("Failed to execute nvidia-smi. Make sure you have the utility installed and that you have NVIDIA GPUs.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")

In [4]:
coarseModel = SwinUNETR(img_size=(128,128,128), in_channels=1, out_channels=2, feature_size=12)
coarseModel.load_state_dict(torch.load(coarseModelPath))
coarseModel.eval()
LRModel = SwinUNETR(img_size=(128,128,128), in_channels=1, out_channels=3, feature_size=12)
LRModel.load_state_dict(torch.load(LRModelPath))
LRModel.eval()
rightModel = SwinUNETR(img_size=(128,128,128), in_channels=1, out_channels=4, feature_size=12)
rightModel.load_state_dict(torch.load(rightModelPath))
rightModel.eval()
leftModel = SwinUNETR(img_size=(128,128,128), in_channels=1, out_channels=3, feature_size=12)
leftModel.load_state_dict(torch.load(leftModelPath))
leftModel.eval()
print('done')

done


In [5]:
coarseModel.to("cuda:0")
LRModel.to("cuda:0")
rightModel.to("cuda:0")
leftModel.to("cuda:0")
print('moved to cuda')

moved to cuda


In [6]:
display_nvidia_smi_memory_usage()

GPU 0: Total Memory: 49140 MiB, Used Memory: 2078 MiB, Free Memory: 46598 MiB
GPU 1: Total Memory: 49140 MiB, Used Memory: 13 MiB, Free Memory: 48672 MiB


In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
dataset = oneFolderFinalDataset("/home/gabriela/Documents/validation_cases/original_cancer_cases")
total_samples = len(dataset)
#train_size = int(0.8 * total_samples)
#val_size = total_samples - train_size
#train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator = torch.Generator().manual_seed(42))
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

cuda:0


In [8]:
def bbox2_3D(img,margin=5):
    r = np.any(img, axis=(1, 2))
    c = np.any(img, axis=(0, 2))
    z = np.any(img, axis=(0, 1))
    xmin, xmax = np.where(r)[0][[0, -1]]
    ymin, ymax = np.where(c)[0][[0, -1]]
    zmin, zmax = np.where(z)[0][[0, -1]]
    xmin = max(0,xmin-margin)
    xmax = min(img.shape[0],xmax+margin)
    ymin = max(0,ymin-margin)
    ymax = min(img.shape[1],ymax+margin)
    zmin = max(0,zmin-margin)
    zmax = min(img.shape[2],zmax+margin)
    return xmin, xmax, ymin, ymax, zmin, zmax

def getLargest(arr, num=None):
    if num is None:
        num = len(np.unique(arr))-1
    # Use cc3d to find connected components
    connected_components = cc3d.connected_components(arr)

    # Count volume for each component
    volume_count = {}
    for component in np.unique(connected_components):
        if component == 0:  # background
            continue
        volume_count[component] = np.sum(connected_components == component)

    # Sort by volume
    sorted_components = sorted(volume_count.items(), key=lambda x: x[1], reverse=True)

    # Take two largest components
    largest_components = [comp_id for comp_id, _ in sorted_components[:num]]

    # Create a new 3D array with only the two largest components
    new_array_3d = np.zeros_like(arr)
    for comp_id in largest_components:
        new_array_3d[connected_components == comp_id] = np.unique(arr[connected_components == comp_id])[0]   # or comp_id to retain original labels

    return new_array_3d


def rescaleBounds(bounds,originalImageShape,imageShape):
    bounds = list(bounds)
    xScale = originalImageShape[0]/imageShape[0]
    yScale = originalImageShape[1]/imageShape[1]
    zScale = originalImageShape[2]/imageShape[2]
    bounds[0] = int(bounds[0]*xScale)
    bounds[1] = int(bounds[1]*xScale)
    bounds[2] = int(bounds[2]*yScale)
    bounds[3] = int(bounds[3]*yScale)
    bounds[4] = int(bounds[4]*zScale)
    bounds[5] = int(bounds[5]*zScale)
    return bounds

def fitInBounds(rescaledBounds, intermediateImageBounds):
    finalBounds = list(rescaledBounds)
    
    # The intermediateImageBounds should have format [x_min, x_max, y_min, y_max, z_min, z_max]
    
    # adjust the x coordinates
    finalBounds[0] = rescaledBounds[0] + intermediateImageBounds[0]
    finalBounds[1] = rescaledBounds[1] + intermediateImageBounds[0]
    
    # adjust the y coordinates
    finalBounds[2] = rescaledBounds[2] + intermediateImageBounds[2]
    finalBounds[3] = rescaledBounds[3] + intermediateImageBounds[2]
    
    # adjust the z coordinates
    finalBounds[4] = rescaledBounds[4] + intermediateImageBounds[4]
    finalBounds[5] = rescaledBounds[5] + intermediateImageBounds[4]
    
    return finalBounds

    

In [9]:
def compute_iou(array1, array2, num_classes):
    iou_list = []
    
    for cls in range(num_classes):
        # Create binary maps for the current class
        binary_array1 = (array1 == cls)
        binary_array2 = (array2 == cls)
        
        # Calculate intersection and union
        intersection = np.logical_and(binary_array1, binary_array2).sum()
        union = np.logical_or(binary_array1, binary_array2).sum()
        
        # Calculate IoU
        if union == 0:
            iou = 1.0  # If both arrays have 0 instances for this class, consider it a perfect match.
        else:
            iou = intersection / union
            
        iou_list.append(iou)
        
    return iou_list

In [10]:
def getModelOutput(input,model):
    softmax = torch.nn.Softmax(dim=1)
    with torch.no_grad():
        output = model(input)
        output = softmax(output)
        output = torch.argmax(output,dim=1)
        output = output.squeeze(0).cpu().numpy()
    return output


def connectedComponents(arr, threshold=3000, connectivity = 26,num_override=0):
    arr = cc3d.dust(arr,threshold=threshold,connectivity=connectivity,in_place=False)
    if num_override:
        arr = getLargest(arr,num_override)
    else:
        arr = getLargest(arr)
    return arr

def crop(originalImage, bounds):
    return originalImage[bounds[0]:bounds[1],bounds[2]:bounds[3],bounds[4]:bounds[5]]

def prepInput(image,device):
    image = ndimage.zoom(image, (128/image.shape[0], 128/image.shape[1], 128/image.shape[2]), order=0)
    image = torch.from_numpy(image).float().unsqueeze(0).unsqueeze(0).to(device)
    return image

def bodyMorphology(image):
    smallImage = ndimage.zoom(image, (64/image.shape[0], 64/image.shape[1], 64/image.shape[2]), order=0)
    imageThresholded = np.where(smallImage > 128, 1, 0)
    distances = ndimage.distance_transform_edt(imageThresholded)
    shrunkMask = distances > 1
    workingImage = shrunkMask
    for i in range(3):
        workingImage = ndimage.binary_dilation(workingImage,structure=morphology.ball(1),iterations=2)

        workingImage = ndimage.binary_erosion(workingImage,structure=morphology.ball(1),iterations=1)

        workingImage = ndimage.binary_closing(workingImage,structure=morphology.ball(5),iterations=1)
        workingImage = np.where(shrunkMask,1,workingImage)

    workingImage = ndimage.binary_erosion(workingImage,structure=morphology.ball(1),iterations=1)
    workingImage = ndimage.zoom(workingImage, (image.shape[0]/64, image.shape[1]/64, image.shape[2]/64), order=0)

    return workingImage

def newMorphology(image):
    smallImage = ndimage.zoom(image,(128/image.shape[0],128/image.shape[1],128/image.shape[2]),order=0)
    thresholdedImage = smallImage > 128
    #eroded = ndimage.binary_erosion(dustedImage,structure=morphology.ball(1))
    flipped = np.logical_not(thresholdedImage)
    background = connectedComponents(flipped)
    final = np.logical_not(background)
    final = ndimage.binary_erosion(final,structure=morphology.ball(2))
    return final



In [29]:
def pytorchGetLargest(tensor, num = None, threshold=3000, ignoreBackground=True,override_comp_id=None):
    if len(tensor.shape) == 4:
        tensor = tensor.squeeze(0)
    if len(tensor.shape) == 5:
        tensor = tensor.squeeze(0).squeeze(0)
    if tensor.dtype != torch.uint8:
        tensor = tensor.to(torch.uint8)
    test = (tensor > 0).to(torch.uint8)
    connected_components = connected_components_labeling(test)
    if override_comp_id is None:
        volume_count = {}
        for component in connected_components.unique():
            if component == 0 and ignoreBackground:  # background
                continue
            count = torch.sum(connected_components == component).item()
            if count < threshold:
                continue
            volume_count[component.item()] = count
        if num is None:
            num = len(volume_count)
        # Sort by volume
        sorted_components = sorted(volume_count.items(), key=lambda x: x[1], reverse=True)

        # Take two largest components
        largest_components = [comp_id for comp_id, _ in sorted_components[:num]]
    else:
        largest_components = override_comp_id

    # Create a new 3D array with only the two largest components
    new_array_3d = torch.zeros_like(tensor)
    for comp_id in largest_components:
        new_array_3d[connected_components == comp_id] = tensor[connected_components == comp_id]   # or comp_id to retain original labels

    return new_array_3d.unsqueeze(0).unsqueeze(0)
    
def torchGetModelOutput(input,model):
    softmax = torch.nn.Softmax(dim=1)
    with torch.no_grad():
        output = model(input)
        output = softmax(output)
        output = torch.argmax(output,dim=1).unsqueeze(0)
    return output.to(torch.uint8)

def pytorchBinaryErosion(tensor):
    ball = morphology.ball(2)
    struct_elem = torch.tensor(ball, dtype=torch.float32)
    struct_elem = struct_elem.view(1, 1, *struct_elem.size()).cuda()

    # Perform 3D convolution with structuring element
    conv_result = torch.nn.functional.conv3d(tensor.float(), struct_elem, padding=2)

    # Binary erosion is equivalent to finding where the convolution result
    # is equal to the sum of the structuring element
    erosion_result = (conv_result == struct_elem.sum().item()).float()
    return erosion_result

def torchMorphology(smallTensor):
    thresholdedTensor = (smallTensor > 128)
    flippedTensor = torch.logical_not(thresholdedTensor).to(torch.uint8)
    ccTensor = pytorchGetLargest(flippedTensor,ignoreBackground=False,override_comp_id=[flippedTensor[:,:,0,0,0].item()])
    finalTensor = pytorchBinaryErosion(torch.logical_not(ccTensor))
    return finalTensor

def torchbbox2_3D(img,margin=5):
    if len(img.shape) > 3:
        for i in range(len(img.shape)-3):
            img = img.squeeze(0)
    r = img.any(dim=2).any(dim=1)
    c = img.any(dim=2).any(dim=0)
    z = img.any(dim=1).any(dim=0)
    xmin, xmax = torch.where(r)[0][[0, -1]]
    ymin, ymax = torch.where(c)[0][[0, -1]]
    zmin, zmax = torch.where(z)[0][[0, -1]]
    
    xmin = max(0,xmin-margin)
    xmax = min(img.shape[0],xmax+margin)
    ymin = max(0,ymin-margin)
    ymax = min(img.shape[1],ymax+margin)
    zmin = max(0,zmin-margin)
    zmax = min(img.shape[2],zmax+margin)
    return int(xmin), int(xmax), int(ymin), int(ymax), int(zmin), int(zmax)

def torchPrep(image):
    print(image.shape)
    image = torch.nn.functional.interpolate(image,size=(128,128,128),mode='nearest')
    return image

def pytorchBinaryDilation(tensor, selem_radius=3):
    ball = morphology.ball(selem_radius)
    struct_elem = torch.tensor(ball, dtype=torch.float32)
    struct_elem = struct_elem.view(1, 1, *struct_elem.size()).cuda()

    # Perform 3D convolution with structuring element
    conv_result = torch.nn.functional.conv3d(tensor.float(), struct_elem, padding=selem_radius)

    # Binary dilation is equivalent to finding where the convolution result
    # is greater than 0
    dilation_result = (conv_result > 0).float()

    return dilation_result

def torchCrop(originalImage, bounds):
    return originalImage[:,:,bounds[0]:bounds[1],bounds[2]:bounds[3],bounds[4]:bounds[5]]

def torchRescaleBounds(bounds,originalImageShape,imageShape):
    bounds = list(bounds)
    xScale = originalImageShape[2]/imageShape[2]
    yScale = originalImageShape[3]/imageShape[3]
    zScale = originalImageShape[4]/imageShape[4]
    bounds[0] = int(bounds[0]*xScale)
    bounds[1] = int(bounds[1]*xScale)
    bounds[2] = int(bounds[2]*yScale)
    bounds[3] = int(bounds[3]*yScale)
    bounds[4] = int(bounds[4]*zScale)
    bounds[5] = int(bounds[5]*zScale)
    return bounds


In [12]:
def fullPipeline(originalImage,coarseModel,LRModel,leftModel,rightModel,device, diagnostics= [False,False,False,False,False,False]):
    selem = morphology.ball(3)
    originalImage = originalImage.to(device)
    coarseImage = torch.nn.functional.interpolate(originalImage,size=(128,128,128),mode='nearest')
    originalImage = originalImage.squeeze(0).squeeze(0).cpu().numpy()
    if diagnostics[0]:
       viewer1 = napari.view_image(originalImage)
    
    #Segment lung - coarse model outputs binary mask of what is lung and not
    coarseOutput = getModelOutput(coarseImage,coarseModel)
    coarseImage = coarseImage.squeeze(0).squeeze(0).cpu().numpy()
    bodyMask = newMorphology(coarseImage)
    coarseOutput = np.where(bodyMask,coarseOutput,0)
    coarseOutput = connectedComponents(coarseOutput,num_override=2)
    coarseBounds128 = bbox2_3D(coarseOutput)
    coarseBounds = rescaleBounds(coarseBounds128,originalImage.shape,coarseOutput.shape)
    coarseCropped = crop(originalImage,coarseBounds)
    if diagnostics[1]:
        coarseModelViewer= napari.view_image(coarseImage)
        coarseModelViewer.add_image(coarseOutput)
        viewer2 = napari.view_image(originalImage)
        coarseBox = np.zeros(originalImage.shape)
        coarseBox[coarseBounds[0]:coarseBounds[1],coarseBounds[2]:coarseBounds[3],coarseBounds[4]:coarseBounds[5]] = 1
        viewer2.add_image(coarseBox)
    
    #Segment LR - LR model outputs mask of what is left and right lung
    LRInput = prepInput(coarseCropped,device)
    LROutput = getModelOutput(LRInput,LRModel)
    LROutput = connectedComponents(LROutput)
    leftOutput = np.where(LROutput==1,1,0)
    rightOutput = np.where(LROutput==2,1,0)
    leftBounds128 = bbox2_3D(leftOutput,margin=1)
    rightBounds128 = bbox2_3D(rightOutput,margin=1)
    leftCoarseBounds = rescaleBounds(leftBounds128,coarseCropped.shape,leftOutput.shape)
    rightCoarseBounds = rescaleBounds(rightBounds128,coarseCropped.shape,rightOutput.shape)
    leftBounds = fitInBounds(leftCoarseBounds,coarseBounds)
    rightBounds = fitInBounds(rightCoarseBounds,coarseBounds)
    leftCropped = crop(originalImage,leftBounds)
    rightCropped = crop(originalImage,rightBounds)

    LRFullSizeMask = np.zeros(originalImage.shape)
    LRCoarseSize = ndimage.zoom(LROutput, (coarseCropped.shape[0]/128,coarseCropped.shape[1]/128,coarseCropped.shape[2]/128), order=0)
    LRFullSizeMask[coarseBounds[0]:coarseBounds[1],coarseBounds[2]:coarseBounds[3],coarseBounds[4]:coarseBounds[5]] = LRCoarseSize
    LRFullSizeMask = np.where(LRFullSizeMask,1,0)
    LRMaskDilated = morphology.binary_dilation(LRFullSizeMask,footprint=selem)

    if diagnostics[2]:
       viewer3 = napari.view_image(originalImage)
       viewer3.add_image(LRFullSizeMask,colormap="gist_earth",contrast_limits=(0,2),opacity=.5)
    
    # Get and post-process left lobe model output
    leftInput = prepInput(leftCropped,device)
    leftLobeOutput = getModelOutput(leftInput,leftModel)
    if diagnostics[3]:
        leftLobeViewer = napari.view_image(leftInput.squeeze(0).squeeze(0).cpu().numpy())
        leftLobeViewer.add_image(leftLobeOutput)
    leftLobeOutput = ndimage.zoom(leftLobeOutput, (leftCropped.shape[0]/128,leftCropped.shape[1]/128,leftCropped.shape[2]/128), order=0)
    
    # Get and post-process right lobe model output
    rightInput = prepInput(rightCropped,device)
    rightLobeOutput = getModelOutput(rightInput,rightModel)       
    if diagnostics[4]:
        rightLobeViewer = napari.view_image(rightInput.squeeze(0).squeeze(0).cpu().numpy())
        rightLobeViewer.add_image(rightLobeOutput)
    rightLobeOutput = ndimage.zoom(rightLobeOutput, (rightCropped.shape[0]/128,rightCropped.shape[1]/128,rightCropped.shape[2]/128), order=0)


    #adjust right lobe output (0,1,2,3) to (0,3,4,5)
    rightLobeOutput = rightLobeOutput + 2
    rightLobeOutput = np.where(rightLobeOutput==2,0,rightLobeOutput)
    
    #Assemble final mask
    finalMask = np.zeros(originalImage.shape)
    leftFullSize = np.zeros(originalImage.shape)
    rightFullSize = np.zeros(originalImage.shape)
    leftFullSize[leftBounds[0]:leftBounds[1],leftBounds[2]:leftBounds[3],leftBounds[4]:leftBounds[5]] = leftLobeOutput
    rightFullSize[rightBounds[0]:rightBounds[1],rightBounds[2]:rightBounds[3],rightBounds[4]:rightBounds[5]] = rightLobeOutput
    finalMask = np.where(leftFullSize, leftFullSize, finalMask)
    finalMask = np.where(rightFullSize, rightFullSize, finalMask)

    finalMask = np.where(LRMaskDilated,finalMask,0)
    #finalMask = connectedComponents(finalMask)

    finalMask = finalMask.astype(np.uint8)
    if diagnostics[5]:
        leftbox = np.zeros(originalImage.shape)
        leftbox[leftBounds[0]:leftBounds[1],leftBounds[2]:leftBounds[3],leftBounds[4]:leftBounds[5]] = 1
        rightbox = np.zeros(originalImage.shape)
        rightbox[rightBounds[0]:rightBounds[1],rightBounds[2]:rightBounds[3],rightBounds[4]:rightBounds[5]] = 1
        viewer4 = napari.view_image(originalImage)
        viewer4.add_image(finalMask, colormap="gist_earth",contrast_limits=(0,5),opacity=.5)
        viewer4.add_image(leftbox,colormap="red",opacity=.5)
        viewer4.add_image(rightbox,colormap="blue",opacity=.5)
    return finalMask
            

In [32]:
def fullPytorchPipeline(originalImage,coarseModel,LRModel,leftModel,rightModel,device):
    originalImage = originalImage.to(device)
    coarseImage = torch.nn.functional.interpolate(originalImage,size=(128,128,128),mode='nearest')    
    #Segment lung - coarse model outputs binary mask of what is lung and not
    coarseOutput = torchGetModelOutput(coarseImage,coarseModel)
    
    bodyMask = torchMorphology(coarseImage)
    coarseOutput = torch.where(bodyMask > 0,coarseOutput,0)
    coarseOutput = pytorchGetLargest(coarseOutput,num=2) #HWD
    
    coarseBounds128 = torchbbox2_3D(coarseOutput)
    coarseBounds = torchRescaleBounds(coarseBounds128,originalImage.shape,coarseOutput.shape)
    coarseCropped = torchCrop(originalImage,coarseBounds)
    
    #Segment LR - LR model outputs mask of what is left and right lung
    LRInput = torchPrep(coarseCropped) #HWD -> NCHWD
    LROutput = torchGetModelOutput(LRInput,LRModel)
    print('lr')
    LROutput = pytorchGetLargest(LROutput,num=2)
    LRFullSizeMask = torch.zeros(originalImage.shape).cuda()
    LRCoarseSize = torch.nn.functional.interpolate(LROutput,size=coarseCropped.shape[2:],mode='nearest-exact')
    LRFullSizeMask[:,:,coarseBounds[0]:coarseBounds[1],coarseBounds[2]:coarseBounds[3],coarseBounds[4]:coarseBounds[5]] = LRCoarseSize
    #iewer = napari.view_image(LROutput.cpu().numpy(),contrast_limits=(0,2))
    
    leftOutput = torch.where(LROutput==1,1,0)
    rightOutput = torch.where(LROutput==2,1,0)
    leftOutput = pytorchGetLargest(leftOutput,num=1)
    rightOutput = pytorchGetLargest(rightOutput,num=1)
    leftBounds128 = torchbbox2_3D(leftOutput,margin=1)
    rightBounds128 = torchbbox2_3D(rightOutput,margin=1)
    leftCoarseBounds = torchRescaleBounds(leftBounds128,coarseCropped.shape,leftOutput.shape)
    rightCoarseBounds = torchRescaleBounds(rightBounds128,coarseCropped.shape,rightOutput.shape)
    leftBounds = fitInBounds(leftCoarseBounds,coarseBounds)
    rightBounds = fitInBounds(rightCoarseBounds,coarseBounds)
    leftCropped = torchCrop(originalImage,leftBounds)
    rightCropped = torchCrop(originalImage,rightBounds)

    print(LROutput.shape)
    #LRCoarseSize = ndimage.zoom(LROutput, (coarseCropped.shape[0]/128,coarseCropped.shape[1]/128,coarseCropped.shape[2]/128), order=0)
    LRCoarseSize = torch.nn.functional.interpolate(LROutput,size=coarseCropped.shape[2:],mode='nearest')
    LRFullSizeMask[:,:,coarseBounds[0]:coarseBounds[1],coarseBounds[2]:coarseBounds[3],coarseBounds[4]:coarseBounds[5]] = LRCoarseSize
    LRFullSizeMask = torch.where(LRFullSizeMask > 0,1,0)
    LRMaskDilated = pytorchBinaryDilation(LRFullSizeMask)



    # Get and post-process left lobe model output
    leftInput = torchPrep(leftCropped)
    leftLobeOutput = torchGetModelOutput(leftInput,leftModel)
    leftLobeOutput = torch.nn.functional.interpolate(leftLobeOutput, size=leftCropped.shape[2:], mode='nearest')
    # Get and post-process right lobe model output
    rightInput = torchPrep(rightCropped)
    rightLobeOutput = torchGetModelOutput(rightInput,rightModel)       
    rightLobeOutput = torch.nn.functional.interpolate(rightLobeOutput, size=rightCropped.shape[2:], mode='nearest')


    #adjust right lobe output (0,1,2,3) to (0,3,4,5)
    rightLobeOutput = rightLobeOutput + 2
    rightLobeOutput = torch.where(rightLobeOutput==2,0,rightLobeOutput)
    
    #Assemble final mask
    finalMask = torch.zeros(originalImage.shape).cuda()
    leftFullSize = torch.zeros(originalImage.shape).cuda()
    rightFullSize = torch.zeros(originalImage.shape).cuda()
    leftFullSize[:,:,leftBounds[0]:leftBounds[1],leftBounds[2]:leftBounds[3],leftBounds[4]:leftBounds[5]] = leftLobeOutput
    rightFullSize[:,:,rightBounds[0]:rightBounds[1],rightBounds[2]:rightBounds[3],rightBounds[4]:rightBounds[5]] = rightLobeOutput
    finalMask = torch.where(leftFullSize > 0, leftFullSize, finalMask)
    finalMask = torch.where(rightFullSize > 0, rightFullSize, finalMask)

    finalMask = torch.where(LRMaskDilated > 0,finalMask,0)
    #finalMask = connectedComponents(finalMask)

    finalMask = finalMask.to(torch.uint8)
    return finalMask
            

In [30]:
import einops

In [14]:
matFile = loadmat("/home/gabriela/Documents/gabiSegmentationProject/B22T50_AdamWRMSScreening0_Sep13_162154corrected.mat")
image = matFile['B22T50_AdamWRMSScreening0_Sep13_162154']

In [154]:
print(matFile.keys())

dict_keys(['__header__', '__version__', '__globals__', 'B22T50_AdamWRMSScreening0_Sep13_162154'])


In [15]:
print(np.min(image),np.max(image))
image = 1 + ((image - 1024) / 1000)
image = np.clip(image, 0, 1) * 255

0 2061


In [17]:
image = torch.tensor(image).float().unsqueeze(0).unsqueeze(0)

In [None]:
output, segInput = fullPytorchPipeline(image,coarseModel,LRModel,leftModel,rightModel,device)
viewer = napari.view_image(segInput.cpu().numpy())
viewer.add_image(output.cpu().numpy(),contrast_limits=(0,2),colormap='gist_earth',opacity=.5)

In [33]:
imageSeg = fullPytorchPipeline(image,coarseModel,LRModel,leftModel,rightModel,device)
viewer = napari.view_image(image.cpu().numpy())
viewer.add_image(imageSeg.cpu().numpy(),contrast_limits=(0,5),colormap='gist_earth',opacity=.5)

torch.Size([1, 1, 188, 280, 89])
lr
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 1, 157, 146, 81])
torch.Size([1, 1, 156, 116, 82])


<Image layer 'Image [1]' at 0x7fb81d72fad0>

In [167]:
print(dataset.imageList)
inhale,exhale,filename = dataset[5]
inhale = inhale.unsqueeze(0)
exhale = exhale.unsqueeze(0)

['B10_00.mat', 'B14_00.mat', 'B18_00.mat', 'B20_00.mat', 'B21_00.mat', 'B22_00.mat', 'B23_00.mat', 'B24_00.mat', 'B25_00.mat', 'B26_00.mat', 'B27_00.mat', 'B28_00.mat', 'B29_00.mat', 'B30_00.mat', 'B31_00.mat', 'B32_00.mat', 'B37_00.mat', 'B38_00.mat', 'B39_00.mat', 'B40_00.mat', 'B41_00.mat', 'B43_00.mat', 'B44_00.mat', 'B45_00.mat', 'B9_00.mat']
0 2649


In [169]:
viewer = napari.view_image(exhale.cpu().numpy())
viewer.add_image(finalMask_exhale.cpu().numpy(),contrast_limits=(0,5),colormap='gist_earth',opacity=.5)

<Image layer 'Image [1]' at 0x7f1136cd5cd0>

In [168]:
finalMask_exhale = fullPytorchPipeline(exhale,coarseModel,LRModel,leftModel,rightModel,device)


52160
866
4
3
6
68
438
1
2
1
23
1
10
14
61
1
3
2
10
9
6
33
1
11
1
{628067: 52160}
torch.Size([1, 1, 188, 280, 89])
lr
2
463052
1
2
1
128
1
1
2
3
1
96
3
2
3
33
5
1
1
53
6
19
1
1
1
{201311: 463052}
262282
1
{201311: 262282}
200747
1
1
14
1
3
1
1
{273237: 200747}
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 1, 159, 146, 82])
torch.Size([1, 1, 154, 116, 83])


In [15]:
with torch.no_grad():
    for i, (inhale,exhale, filename) in enumerate(loader):
       # inhale = inhale.squeeze(0).squeeze(0).numpy()
        #print(np.min(inhale),np.max(inhale))
        diagnostics = [False,False,False,False,False,True]
        try:
            finalMask_inhale = fullPytorchPipeline(inhale,coarseModel,LRModel,leftModel,rightModel,device)
            #inhale_seg = nib.Nifti1Image(finalMask_inhale.astype(np.uint8), affine=np.eye(4))
            #nib.save(inhale_seg, f'/home/gabriela/Documents/validation_cases/inhale_masks/{filename}_inhale.nii')
            finalMask_exhale = fullPytorchPipeline(exhale,coarseModel,LRModel,leftModel,rightModel,device)
            #exhale_seg = nib.Nifti1Image(finalMask_exhale.astype(np.uint8), affine=np.eye(4))
            #nib.save(exhale_seg, f'/home/gabriela/Documents/validation_cases/exhale_masks/{filename}_exhale.nii')
            Viewer = napari.view_image(finalMask_inhale.cpu().numpy())
            if np.array(diagnostics).any():
                break
                
        except IndexError as e:
            print(e)
            print(filename)
            break

0 4095
0 4095
[0 1]
0
 4095{661593: 17970, 729669: 15007}
torch.Size([1, 1, 43, 60, 4])
torch.Size([1, 1, 128, 128, 128])
[0 1 2]
0 4095
{48425: 306843, 131083: 81818, 1457273: 31980}
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
torch.Size([1, 1, 128, 128, 128])
index is out of bounds for dimension with size 0
['B10_00.mat']


In [14]:
#calculate median, min, max, 25% 75% percentile of each iou
median = np.median(allIoUs,axis=0)
minimum = np.min(allIoUs,axis=0)
maximum = np.max(allIoUs,axis=0)
percentile25 = np.percentile(allIoUs,25,axis=0)
percentile75 = np.percentile(allIoUs,75,axis=0)
mean = np.mean(allIoUs,axis=0)
std = np.std(allIoUs,axis=0)
print(f"mean: {mean}")
print(f"std: {std}")
print(f"median: {median}")
print(f"minimum: {minimum}")
print(f"maximum: {maximum}")
print(f"percentile25: {percentile25}")
print(f"percentile75: {percentile75}")

#count rows with IoUs below .8
count = 0
completeFailCount = 0
for row in allIoUs:
    flag = np.array([True for val in row if val < 0.8]).any()
    completeFail = np.array([True for val in row if val < .10]).any()
    if flag:
        count += 1
    if completeFail:
        completeFailCount += 1
print(f"Number of rows with at least one IoU below .8: {count}")
print(f"Number of rows with at least one IoU below .1: {completeFailCount}")


mean: [0.99539514 0.9478496  0.94130024 0.93314625 0.88585846 0.94128153]
std: [0.0013794  0.01497322 0.01926996 0.03128335 0.05807555 0.01709858]
median: [0.99547702 0.95141895 0.94552664 0.94149879 0.90300134 0.94462566]
minimum: [0.99123203 0.83449172 0.81887321 0.63325052 0.57119522 0.83333865]
maximum: [0.9986088  0.96713803 0.96954426 0.96500875 0.94338542 0.96561018]
percentile25: [0.99445298 0.94268888 0.93568856 0.92967128 0.87999438 0.93456552]
percentile75: [0.99646507 0.95708696 0.95356125 0.94964599 0.91893935 0.95257833]
Number of rows with at least one IoU below .8: 31
Number of rows with at least one IoU below .1: 0


index 0 is out of bounds for axis 0 with size 0
['17704Z_EXP_image.nii.gz']

In [17]:
badFiles = []
with open('../ious.txt','r') as iouf:
    lines = iouf.readlines()
    for line in lines:
        ious, filename = line.split('-')
        ious = ious.strip().strip('[').strip(']').split(',')
        ious = np.array([float(val) for val in ious])
        filename = filename.strip()
        if (ious < .8).any():
            #print(f"{filename} has at least one IoU below .8")
            print(f"{filename}: {ious}")
            badFiles.append(filename)
badFiles = [val[2:-2] for val in badFiles]
        

['22169K_EXP_image.nii.gz']: [0.99569558 0.95644109 0.94816801 0.89872881 0.76741352 0.94431173]
['15460N_EXP_image.nii.gz']: [0.99815988 0.93218118 0.90603531 0.93303655 0.73704107 0.89392002]
['12596X_INSP_image.nii.gz']: [0.9935656  0.96171443 0.95769247 0.87585702 0.75122895 0.95124937]
['11401D_EXP_image.nii.gz']: [0.9979965  0.92898779 0.9224925  0.83855884 0.75783769 0.91827999]
['11015Y_EXP_image.nii.gz']: [0.9945723  0.96174481 0.95938263 0.90511807 0.75851981 0.96037623]
['10795T_EXP_image.nii.gz']: [0.99700719 0.8665107  0.82409845 0.83751104 0.57155334 0.91124893]
['16553Z_EXP_image.nii.gz']: [0.99534504 0.9406398  0.92054885 0.91516137 0.63344579 0.86685302]
['12895H_EXP_image.nii.gz']: [0.99523863 0.91570598 0.93001188 0.93450595 0.74939736 0.94516784]
['19370G_INSP_image.nii.gz']: [0.99365281 0.96236574 0.96522614 0.87827307 0.73182768 0.96152413]
['18558T_INSP_image.nii.gz']: [0.99521442 0.95785104 0.95418646 0.63325052 0.57119522 0.95869274]
['11659Q_INSP_image.nii.gz'

In [20]:
import pandas as pd

df = pd.read_csv('/home/aaronluong/pft/twoTimepoint.txt',sep='\t')


Columns (75,447,674,848,855) have mixed types. Specify dtype option on import or set low_memory=False.


In [34]:
phase = ['INSP' in val for val in badFiles]
print(np.unique(phase,return_counts=True))

(array([False,  True]), array([19, 12]))


In [31]:
badPatientIds = [val[:6] for val in badFiles]
badSubset = df[df['sid'].isin(badPatientIds)]
print(badSubset.loc[:,['finalgold_baseline']].value_counts())

finalgold_baseline
2.0                   9
3.0                   9
0.0                   7
4.0                   3
1.0                   2
dtype: int64


In [18]:


indices = [dataset.findIndex(filename) for filename in badFiles]
print(indices)

[1490, 682, 319, 168, 120, 94, 794, 348, 1179, 1081, 191, 1052, 203, 1801, 1492, 1518, 1198, 1680, 1576, 732, 515, 791, 394, 271, 22, 491, 100, 237, 1503, 790, 240]


In [46]:
viewer = napari.Viewer()
idx = indices[1]
originalImage,originalLabel, filename = dataset[idx]
originalLabel = originalLabel.to(device)
originalImage = originalImage.unsqueeze(0)
finalMask = fullPipeline(originalImage,coarseModel,LRModel,leftModel,rightModel,device,diagnostics=[False,False,False,True,True,False])
originalLabel = originalLabel.squeeze(0).squeeze(0).cpu().numpy()
IoUs = compute_iou(finalMask,originalLabel,6)
print(f"{IoUs} - {filename}")
diff = abs(finalMask - originalLabel)
print(np.unique(diff))
viewer.add_image(originalImage.squeeze(0).squeeze(0).cpu().numpy())
viewer.add_image(originalLabel,colormap="gist_earth",contrast_limits=(0,5),opacity=.5)
viewer.add_image(finalMask,colormap="gist_earth",contrast_limits=(0,5),opacity=.5)
viewer.add_image(diff,colormap="red",opacity=.5,contrast_limits=(0,1))
    

`selem` is a deprecated argument name for `binary_dilation`. It will be removed in version 1.0. Please use `footprint` instead.


[0.9982706631260615, 0.9340230302255383, 0.9073126849027756, 0.9362145931751, 0.7379043964674135, 0.8966783809776152] - 15460N_EXP_image.nii.gz
[0. 1. 2. 3. 4. 5.]


<Image layer 'diff' at 0x7f316bae2940>

In [14]:
class testdataset(torch.utils.data.Dataset):
    def __init__(self, imageDirectory):
        self.imageList = sorted([file for file in os.listdir(imageDirectory) if "image" in file])
        #self.imageList = sorted([file for file in os.listdir(imageDirectory)])
        self.labelList = sorted([file for file in os.listdir(imageDirectory) if "mask" in file])
        self.imageDirectory = imageDirectory
    
    def __len__(self):
        x = len(self.imageList)
        return x

    def __getitem__(self, idx):
        imagePath = os.path.join(self.imageDirectory, self.imageList[idx])
        labelPath = os.path.join(self.imageDirectory, self.labelList[idx])
        image = nib.load(imagePath).get_fdata()
        image = 1+ image/1000
        image = np.clip(image,0,1) * 255
        label = nib.load(labelPath).get_fdata()
        #matfile = loadmat(imagePath)
        #inhale = np.array(matfile['T00'], dtype=np.int32)
        #exhale = np.array(matfile['T50'], dtype=np.int32)
        #print(np.min(inhale),np.max(inhale))
        #inhale = 1 + ((inhale - 1024) / 1000)
        #inhale = np.clip(inhale, 0, 1) * 255
        #label = nib.load(labelPath).get_fdata()
        #inhale = torch.from_numpy(inhale).float().unsqueeze(0)
        #exhale = 1 + ((exhale-1024) / 1000)
        #exhale = np.clip(exhale, 0, 1) * 255
        #label = nib.load(labelPath).get_fdata()
        #exhale = torch.from_numpy(exhale).float().unsqueeze(0)
        #label = torch.from_numpy(label).float().unsqueeze(0)

        return image, label
    
    def findIndex(self,filename):
        for i in range(len(self.imageList)):
            if filename == self.imageList[i]:
                return i
        return -1

In [15]:
from skimage import measure
import matplotlib.pyplot as plt

In [16]:
testDataset = testdataset("/home/gabriela/Documents/newCOPD_raw")
image, label = testDataset[5]
print(image.shape)

(512, 512, 534)


In [22]:


def bbox2_3D(img,margin=5):
    r = np.any(img, axis=(1, 2))
    c = np.any(img, axis=(0, 2))
    z = np.any(img, axis=(0, 1))
    xmin, xmax = np.where(r)[0][[0, -1]]
    ymin, ymax = np.where(c)[0][[0, -1]]
    zmin, zmax = np.where(z)[0][[0, -1]]
    print(zmin,zmax)
    xmin = max(0,xmin-margin)
    xmax = min(img.shape[0],xmax+margin)
    ymin = max(0,ymin-margin)
    ymax = min(img.shape[1],ymax+margin)
    zmin = max(0,zmin-margin)
    zmax = min(img.shape[2],zmax+margin)
    return xmin, xmax, ymin, ymax, zmin, zmax

test = label
assert (test == torch.tensor(test).numpy()).all()



viewer = napari.view_image(test)
viewer.add_image(torch.tensor(test).numpy())
torchDebugBox = np.zeros_like(test)
xmin,xmax,ymin,ymax,zmin,zmax = torchbbox2_3D(torch.tensor(test))
torchDebugBox[xmin:xmax,ymin:ymax,zmin:zmax] = 1
xmin,xmax,ymin,ymax,zmin,zmax = bbox2_3D(test)
npDebugBox = np.zeros_like(test)
npDebugBox[xmin:xmax,ymin:ymax,zmin:zmax] = 1
viewer.add_image(npDebugBox)
viewer.add_image(torchDebugBox)
print(test.shape)
print(torchbbox2_3D(torch.tensor(test)))
print(bbox2_3D(test))

25 533
(512, 512, 534)
(87, 380, 35, 482, 20, 534)
25 533
(87, 380, 35, 482, 20, 534)


In [16]:
smallImage = ndimage.zoom(image,(128/image.shape[0],128/image.shape[1],128/image.shape[2]),order=0)
thresholdedImage = smallImage > 128
#eroded = ndimage.binary_erosion(dustedImage,structure=morphology.ball(1))
flipped = np.logical_not(thresholdedImage)
background = connectedComponents(flipped)
final = np.logical_not(background)
final = ndimage.binary_erosion(final,structure=morphology.ball(2))

<Image layer 'final' at 0x7f0446bb4250>

In [13]:
def pytorchGetLargest(tensor, num = None, ignoreBackground=True,override_comp_id=None):
    connected_components = connected_components_labeling(tensor)
    if override_comp_id is None:
        if num is None:
            num = len(tensor.unique()-1)
        tensorView = tensor.cpu().numpy()
        print(np.unique(tensorView))
        viewer = napari.view_image(tensorView)
        viewer.add_image(connected_components.cpu().numpy())
        volume_count = {}
        for component in connected_components.unique():
            if component == 0 and ignoreBackground:  # background
                continue
            volume_count[component.item()] = torch.sum(connected_components == component).item()

        print(volume_count)

        # Sort by volume
        sorted_components = sorted(volume_count.items(), key=lambda x: x[1], reverse=True)

        # Take two largest components
        largest_components = [comp_id for comp_id, _ in sorted_components[:num]]
    else:
        largest_components = override_comp_id

    # Create a new 3D array with only the two largest components
    new_array_3d = torch.zeros_like(tensor)
    for comp_id in largest_components:
        new_array_3d[connected_components == comp_id] = tensor[connected_components == comp_id].unique()[0]   # or comp_id to retain original labels

    return new_array_3d
    

def pytorchBinaryErosion(tensor):
    ball = morphology.ball(2)
    struct_elem = torch.tensor(ball, dtype=torch.float32)
    struct_elem = struct_elem.view(1, 1, *struct_elem.size()).cuda()

    # Perform 3D convolution with structuring element
    conv_result = torch.nn.functional.conv3d(tensor.float(), struct_elem, padding=2)

    # Binary erosion is equivalent to finding where the convolution result
    # is equal to the sum of the structuring element
    erosion_result = (conv_result == struct_elem.sum().item()).float()
    return erosion_result

In [84]:
tensorImage = torch.tensor(image).unsqueeze(0).unsqueeze(0).cuda()
smallTensor = torch.nn.functional.interpolate(tensorImage,size=(128,128,128),mode='nearest')
thresholdedTensor = (smallTensor > 128).squeeze(0).squeeze(0)
flippedTensor = torch.tensor(torch.logical_not(thresholdedTensor),dtype=torch.uint8)
ccTensor = pytorchGetLargest(flippedTensor,ignoreBackground=False,override_comp_id=[flippedTensor[0,0,0].item()])
finalTensor = pytorchBinaryErosion(torch.logical_not(ccTensor).unsqueeze(0).unsqueeze(0))



  flippedTensor = torch.tensor(torch.logical_not(thresholdedTensor),dtype=torch.uint8)


<Image layer 'Image [1]' at 0x7f00c6f33450>

In [20]:
import time
for i in range(10):
    image = testDataset[i]
    tensorImage = torch.tensor(image).unsqueeze(0).unsqueeze(0).cuda()
    numpyStart = time.time()
    smallImage = ndimage.zoom(image,(128/image.shape[0],128/image.shape[1],128/image.shape[2]),order=0)
    thresholdedImage = smallImage > 128
    #eroded = ndimage.binary_erosion(dustedImage,structure=morphology.ball(1))

    flipped = np.logical_not(thresholdedImage)
    background = connectedComponents(flipped)
    final = np.logical_not(background)
    final = ndimage.binary_erosion(final,structure=morphology.ball(2))
    numpyEnd = time.time()
    torchStart = time.time()
    smallTensor = torch.nn.functional.interpolate(tensorImage,size=(128,128,128),mode='nearest')
    thresholdedTensor = (smallTensor > 128).squeeze(0).squeeze(0)
    flippedTensor = torch.tensor(torch.logical_not(thresholdedTensor),dtype=torch.uint8)
    ccTensor = pytorchGetLargest(flippedTensor,ignoreBackground=False,override_comp_id=[flippedTensor[0,0,0].item()])
    #viewer.add_image(ccTensor.cpu().numpy())
    finalTensor = pytorchBinaryErosion(torch.logical_not(ccTensor).unsqueeze(0).unsqueeze(0))
    torchEnd = time.time()
    print(f'Numpy: {numpyEnd-numpyStart}')
    print(f'Torch: {torchEnd-torchStart}')
    print(np.sum(abs(finalTensor.squeeze(0).squeeze(0).cpu().numpy() - final),axis=(0,1,2)))
    

  flippedTensor = torch.tensor(torch.logical_not(thresholdedTensor),dtype=torch.uint8)


Numpy: 0.19847464561462402
Torch: 0.023997068405151367
322887.0
Numpy: 0.1783769130706787
Torch: 0.02505660057067871
712212.0
Numpy: 0.18296170234680176
Torch: 0.024211406707763672
16958.0
Numpy: 0.1741933822631836
Torch: 0.027007579803466797
17880.0
Numpy: 0.1802518367767334
Torch: 0.02519965171813965
18358.0
Numpy: 0.17082643508911133
Torch: 0.02610468864440918
16940.0
Numpy: 0.1807100772857666
Torch: 0.022684097290039062
18131.0
Numpy: 0.17917966842651367
Torch: 0.02402329444885254
18567.0


KeyboardInterrupt: 