In [1]:
from pipeline import segmentationPipeline
import numpy as np
import napari
import os
import matplotlib.pyplot as plt
import torch

In [2]:
# Directory holding the fine-tuned LUNA16 weights (machine-specific absolute path).
WEIGHTS_DIR = '/home/aaronluong/Documents/GitHub/segmentation-pipeline/weights'
# One entry per pipeline model. None presumably keeps that model's default
# weights, and the _R/_L files presumably hold the right-/left-lung lobe
# models -- confirm against segmentationPipeline in pipeline.py.
weight_overrides = [
    None,
    os.path.join(WEIGHTS_DIR, '17_cancer_LUNA16_R_weights.pt'),
    os.path.join(WEIGHTS_DIR, '17_cancer_LUNA16_L_weights.pt'),
]
pipeline = segmentationPipeline(device='cuda:0', weightPathOverrides=weight_overrides)
# To run with the pipeline's built-in weights instead:
# pipeline = segmentationPipeline(device='cuda:0')

In [3]:
def one_hot_encode(tensor, num_classes):
    """
    Expand an integer label tensor into a float32 one-hot tensor.

    A new trailing class axis is appended, so an input of shape S becomes
    S + (num_classes,), on the same device as the input. Values are cast
    with ``.long()`` first, so pass hard labels (e.g. an argmax result), not
    probabilities: fractional values would be truncated.
    """
    labels = tensor.long()
    encoded = torch.zeros(
        (*labels.shape, num_classes), dtype=torch.float32, device=labels.device
    )
    # Write 1.0 at each element's class index along the new last axis.
    encoded.scatter_(-1, labels.unsqueeze(-1), 1.0)
    return encoded
def dice_per_class(y_true, y_pred, num_classes, eps=1e-7):
    """
    Calculate the Dice coefficient (not the loss) for each class.

    Both inputs are hard label maps, i.e. one integer class index per voxel
    (e.g. the ground-truth mask and the argmax of a network's output). They
    are one-hot encoded internally via ``.long()``, so passing probabilities
    does NOT work -- they would be truncated to integer labels.

    Parameters:
    - y_true: Ground-truth label map, e.g. [h, w, d] or [1, 1, h, w, d].
    - y_pred: Predicted label map covering the same voxels as y_true. Up to
      two leading size-1 dims are dropped from each input, so a
      [1, 1, h, w, d] prediction can be compared to an [h, w, d] mask.
    - num_classes: Total number of classes, including background (class 0).
    - eps: Smoothing term added to numerator and denominator; a class absent
      from both inputs therefore scores 1.0 instead of 0/0.

    Returns:
    - torch tensor of shape [num_classes] with the Dice coefficient per class
      (1.0 = perfect overlap, 0.0 = no overlap).

    NOTE: this materializes two float32 one-hot volumes of size
    voxels x num_classes; for very large volumes a count-based approach
    (e.g. torch.bincount) would use far less memory.
    """
    # Drop up to two leading singleton (batch/channel) dims from each input.
    y_true_one_hot = one_hot_encode(y_true, num_classes).squeeze(0).squeeze(0)
    y_pred_one_hot = one_hot_encode(y_pred, num_classes).squeeze(0).squeeze(0)

    def _per_class_count(one_hot):
        # Collapse every axis except the trailing class axis, so the result is
        # [num_classes] no matter how many voxel/batch axes remain (the old
        # hard-coded dim=(0, 1, 2) was only right for exactly three).
        return one_hot.reshape(-1, one_hot.shape[-1]).sum(dim=0)

    intersection = _per_class_count(y_true_one_hot * y_pred_one_hot)
    # |A| + |B| per class (the Dice denominator, not the set union).
    total = _per_class_count(y_true_one_hot) + _per_class_count(y_pred_one_hot)
    return (2. * intersection + eps) / (total + eps)
def average_dice(y_true, y_pred, num_classes, eps=1e-7):
    """
    Mean Dice coefficient over the foreground classes 1 .. num_classes-1.

    Class 0 is treated as background and excluded from the mean.

    Parameters:
    - y_true, y_pred, num_classes, eps: Same meaning as in ``dice_per_class``.

    Returns:
    - 0-dim torch tensor holding the mean foreground Dice coefficient
      (use ``.item()`` for a Python float).

    Raises:
    - ValueError: if num_classes < 2. There is then no foreground class to
      average, and torch.mean of the empty slice would silently yield NaN.
    """
    if num_classes < 2:
        raise ValueError(
            f"average_dice needs at least 2 classes (background + foreground), got {num_classes}"
        )
    per_class_dice = dice_per_class(y_true, y_pred, num_classes, eps)
    # Skip index 0 (background) when averaging.
    return torch.mean(per_class_dice[1:])

In [4]:
def displayDice(results, class_names=('LUL', 'LLL', 'RUL', 'RML', 'RLL')):
    """
    Print the mean foreground Dice and the per-class Dice on one line.

    Parameters:
    - results: Iterable of per-class Dice scores with background first, e.g.
      the tensor returned by ``dice_per_class``. Entry i (for i >= 1) is
      labelled ``class_names[i - 1]``; entry 0 is ignored.
    - class_names: Labels for the foreground classes in class-index order.
      The default matches the remapped lobe labels used in this notebook
      (1-2: left lobes, 3-5: right lobes).

    Raises:
    - ValueError: if there are no foreground scores, or their count does not
      match ``len(class_names)`` (previously extra scores were averaged but
      never printed, and missing ones raised a bare IndexError).
    """
    # float() turns 0-dim tensors / numpy scalars into plain Python floats.
    scores = [float(score) for score in results]
    foreground = scores[1:]
    if not foreground or len(foreground) != len(class_names):
        raise ValueError(
            f"expected {len(class_names)} foreground scores plus background, got {len(scores)} scores"
        )
    average = sum(foreground) / len(foreground)
    per_class = ' '.join(f'{name}: {score:.4f}' for name, score in zip(class_names, foreground))
    print(f'Average: {average:.4f} {per_class}')

In [7]:
# Evaluate the pipeline on every LUNA16 test case and save each segmentation.
TEST_CASE_DIR = '/home/aaronluong/gabiSegmentationProject/test cases/LUNA16'
IMAGE_DIR = os.path.join(TEST_CASE_DIR, 'test_images')
MASK_DIR = os.path.join(TEST_CASE_DIR, 'test_masks')
NUM_CLASSES = 6                # background + 5 lobes after the remap below
SEG_SUFFIX = '_seg21524.npy'   # tag appended to saved segmentation filenames

# sorted(): os.listdir order is arbitrary, so fix it for reproducible output.
for file in sorted(os.listdir(IMAGE_DIR)):
    if not file.endswith('.npy'):
        continue  # skip stray non-array files
    img = np.load(os.path.join(IMAGE_DIR, file))
    # Shift by +1000 and clip to [0, 1000], i.e. keep the raw range [-1000, 0]
    # (presumably HU, air..water) -- confirm this is the window the model expects.
    img = np.clip(img + 1000, 0, 1000)
    img = np.flip(img, 1)

    mask = np.load(os.path.join(MASK_DIR, file.replace('image.npy', 'LobeSegmentation.npy')))
    mask = np.where(mask > 6, mask - 6, mask)  # 7-8 -> 1-2
    mask = np.where(mask > 3, mask - 1, mask)  # 4-6 -> 3-5
    # np.flip returns a negative-stride view that torch cannot wrap; copy it.
    mask = np.ascontiguousarray(np.flip(mask, 1))

    out = pipeline.segment(img, getLR=False, takeLargest=True)

    # Prefix each score line with the case name so outliers can be traced.
    print(file, end=': ')
    # Put the ground truth on whichever device the pipeline's output is on.
    displayDice(dice_per_class(torch.as_tensor(mask, device=out.device), out, num_classes=NUM_CLASSES))
    # Saved into the current working directory, as before.
    np.save(file.replace('.npy', SEG_SUFFIX), out.squeeze(0).squeeze(0).cpu().numpy())
# Visual check of the last case processed (uncomment to open napari):
# viewer = napari.view_image(mask,name='gt',colormap='gist_earth',contrast_limits=(0,5))
# viewer.add_image(out.squeeze(0).squeeze(0).cpu().numpy(),colormap='gist_earth',name='segmentation',contrast_limits=(0,5))

Average: 0.9532 LUL: 0.9792 LLL: 0.9652 RUL: 0.9593 RML: 0.9436 RLL: 0.9187
Average: 0.9406 LUL: 0.9682 LLL: 0.9389 RUL: 0.9420 RML: 0.9204 RLL: 0.9333
Average: 0.7344 LUL: 0.9791 LLL: 0.9646 RUL: 0.8342 RML: 0.0093 RLL: 0.8848
Average: 0.9442 LUL: 0.9624 LLL: 0.9688 RUL: 0.9522 RML: 0.8835 RLL: 0.9542
Average: 0.9218 LUL: 0.9727 LLL: 0.9700 RUL: 0.8908 RML: 0.8206 RLL: 0.9547
Average: 0.9324 LUL: 0.9704 LLL: 0.9256 RUL: 0.9532 RML: 0.8821 RLL: 0.9309
Average: 0.9493 LUL: 0.9793 LLL: 0.9627 RUL: 0.9604 RML: 0.8905 RLL: 0.9538
Average: 0.9408 LUL: 0.9439 LLL: 0.9294 RUL: 0.9501 RML: 0.9325 RLL: 0.9480
Average: 0.9401 LUL: 0.9410 LLL: 0.9293 RUL: 0.9498 RML: 0.9259 RLL: 0.9543
Average: 0.7257 LUL: 0.9406 LLL: 0.9361 RUL: 0.6637 RML: 0.1524 RLL: 0.9355


In [None]:
# viewer.add_image(img,name='image',colormap='gray',contrast_limits=(0,1000))

<Image layer 'image' at 0x7fda04613100>