In [1]:
# data paths Lobke
data_path_training_images = 'training/images'
data_path_training_masks = 'training/mask'
data_path_training_segmentations = 'training/1st_manual'
data_path_secret_test = 'secret_test'


# Verify that every configured data folder is present on disk.
import os

# One entry per folder: (path, wording used when it is missing, wording used when it is found).
_folder_checks = (
    (data_path_training_images, "images", "training images"),
    (data_path_training_masks, "masks", "training masks"),
    (data_path_training_segmentations, "segmentations", "training segmentations"),
    (data_path_secret_test, "secret test set", "secret test set"),
)

for _folder, _missing_label, _found_label in _folder_checks:
    if os.path.exists(_folder):
        print(f"Congrats! You selected the correct {_found_label} folder :)")
    else:
        print(f"Please update your {_missing_label} data path to an existing folder.")


# # weights and biases:
# import wandb
# os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# wandb.login()

Congrats! You selected the correct training images folder :)
Congrats! You selected the correct training masks folder :)
Congrats! You selected the correct training segmentations folder :)
Congrats! You selected the correct secret test set folder :)


In [2]:
# import packages
import glob
import torch
import numpy as np
import monai
from monai.networks.nets import UNet
from monai.transforms import Compose, LoadImaged, ScaleIntensityd, ToTensord, RandFlipd, RandSpatialCropSamplesd, EnsureChannelFirstd, Lambdad, Transposed
from monai.data import DataLoader, CacheDataset, ITKWriter
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
from monai.transforms import Transform
from IPython.display import display, clear_output


In [3]:
# settings for important parameters

# data parameters
num_test_files = 2
num_validation_files = 2
# Everything that is not held out for validation/test is used for training.
# The literal 20 is the assumed number of annotated images -- TODO confirm against the data folder.
num_train_files = 20 - num_test_files - num_validation_files
# roi_size = 192          # region of interest size
trans_prob = 0.5        # transform probability
batch_size = 2

# model parameters
# num_epochs = 1000 # train for at least 4000 epochs, but no clear overfitting at 10000 epochs
validation_wait = 10    # the validation loss is evaluated once every `validation_wait` epochs

# hyperparameter values for experiments
num_epochssave = [1000, 2000, 4000, 8000, 16000]
# num_epochs = max(num_epochssave)
# lossfunctions = [monai.losses.DiceLoss(sigmoid=True, include_background=False), monai.losses.DiceCELoss(sigmoid=True, include_background=False)]
# lossfunctionsnames = ["Dice", "DiceCE"]
# roi_sizes = [32, 64, 128, 256]

# for testing (these deliberately override the experiment values above)
# num_epochssave stays a list so that membership tests (`epoch+1 in num_epochssave`)
# and its use as a plot axis keep working with the quick-test configuration.
num_epochssave = [5]
num_epochs = 10
lossfunctions = [monai.losses.DiceLoss(sigmoid=True, include_background=False)]
lossfunctionsnames = ["Dice"]   # must stay index-aligned with `lossfunctions`; used in output file names
roi_sizes = [128]


In [4]:
# set random seeds
# A single fixed seed is shared by every seeding call in this cell.
Seed = 2071293819 # 01111011011101010110111101111011 in binary
monai.utils.set_determinism(seed=Seed) # set seed for model reproducibility
# NOTE(review): set_determinism presumably already seeds PyTorch and configures cuDNN;
# the explicit calls below repeat this with the same seed -- confirm against the MONAI docs.

# Set seed for PyTorch
torch.manual_seed(Seed)
torch.cuda.manual_seed(Seed)

# Ensure deterministic behavior pytorch
# (request deterministic cuDNN algorithms and turn cuDNN benchmarking off)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
# Collect the dataset file paths; each list is sorted so that entries pair up by position
image_files = sorted(glob.glob(os.path.join(data_path_training_images, '*.tif')))
mask_files = sorted(glob.glob(os.path.join(data_path_training_masks, '*.gif')))
segmentation_files = sorted(glob.glob(os.path.join(data_path_training_segmentations, '*.gif')))

# Pair every image with its segmentation once, then cut the list of pairs into
# consecutive training / validation / test parts.
_paired_files = [{"img": img, "seg": seg} for img, seg in zip(image_files, segmentation_files)]
_train_end = num_train_files
_validation_end = num_train_files + num_validation_files

train_files = _paired_files[0:_train_end]
validation_files = _paired_files[_train_end:_validation_end]
test_files = _paired_files[_validation_end:]

def rgb_to_grayscale(img):
    """Reduce an image to a single channel by keeping index 1 of its last axis.

    Despite the name, no weighted RGB mix is computed: the function simply
    selects one slice (the green channel, if the last axis is ordered R, G, B --
    assumption, confirm against the image reader).

    Args:
        img: array-like supporting numpy-style indexing whose last axis has
            at least two entries.

    Returns:
        The input with its last axis removed; no channel axis is added back.
    """
    channel_index = 1
    return img[..., channel_index]

def visualize_sample(sample, title1=None, title2=None):
    """Show one sample twice side by side: plain, and with its mask overlaid in green.

    Args:
        sample: mapping with entries 'img' and 'seg', each indexable as
            [0, 0, :, :] (first item, first channel of a batch).
        title1: optional title for the left (plain image) panel.
        title2: optional title for the right (image + mask) panel.
    """
    image = np.squeeze(sample['img'][0, 0, :, :])
    mask = np.squeeze(sample['seg'][0, 0, :, :])

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[10, 7])

    # Both panels start from the same grayscale image.
    for panel in (ax1, ax2):
        panel.imshow(image, 'gray')

    # Hide every pixel where the mask is 0 so only the annotated pixels are tinted.
    overlay_mask = np.ma.masked_where(mask == 0, mask == 1)
    ax2.imshow(overlay_mask, 'Greens', alpha=0.7, clim=[0, 1], interpolation='nearest')

    for panel, title in ((ax1, title1), (ax2, title2)):
        if title is not None:
            panel.set_title(title)

    plt.show()
    
class ThresholdTransform(Transform):
    """Dictionary transform that zeroes every value below a fixed threshold.

    The arrays stored under the selected keys are modified in place and the
    same dictionary object is returned.
    """

    def __init__(self, keys, threshold):
        """Remember which dictionary entries to process and the cut-off value."""
        self.keys = keys
        self.threshold = threshold

    def __call__(self, data):
        """Apply the threshold to ``data[key]`` for every configured key and return ``data``."""
        for name in self.keys:
            pixels = data[name]
            below_threshold = pixels < self.threshold
            pixels[below_threshold] = 0
            data[name] = pixels
        return data

In [6]:
# added to visualize the data after pre-processing (i.e. after the transformations have been applied)

def visualize_five_samples_across_batches(data_loader):
    """Plot the first five samples yielded by a loader, one figure row per sample.

    Each row shows the image, its mask and a 50-bin histogram of the image's
    pixel values. Batches are consumed only until five samples have been drawn.

    Args:
        data_loader: iterable of batches; every batch is a mapping with tensors
            under 'img' and 'seg' whose first axis indexes the samples.
    """
    max_images = 5
    # max_images rows, 3 columns (image, mask, histogram)
    fig, axs = plt.subplots(max_images, 3, figsize=(18, 30))

    shown = 0  # number of rows already filled
    for batch in data_loader:
        for idx in range(batch['img'].shape[0]):
            if shown >= max_images:
                break

            image = batch['img'][idx]
            mask = batch['seg'][idx]
            image_np = image.numpy().squeeze()
            mask_np = mask.numpy().squeeze()

            image_ax, mask_ax, hist_ax = axs[shown]
            number = shown + 1  # 1-based label used in the titles

            image_ax.imshow(image_np, cmap='gray')
            image_ax.set_title(f'Image {number}')
            image_ax.axis('off')

            mask_ax.imshow(mask_np, cmap='gray')
            mask_ax.set_title(f'Mask {number}')
            mask_ax.axis('off')

            hist_ax.hist(image_np.flatten(), bins=50, color='gray')
            hist_ax.set_title(f'Histogram of Image {number} Pixel Values')
            hist_ax.set_xlabel('Pixel Value')
            hist_ax.set_ylabel('Frequency')
            hist_ax.set_xticks(np.arange(0, 1.1, 0.1))
            hist_ax.grid(True)

            shown += 1

        # Leave the batch loop too once all rows are filled.
        if shown >= max_images:
            break

    plt.tight_layout()
    plt.show()

In [None]:
# Loop over loss functions and roi sizes; train one U-Net per combination and
# save the per-epoch training / validation loss curves of every run.

# np.save does not create missing folders, so make sure the output folder exists.
os.makedirs("Experiment results", exist_ok=True)

for index, loss_function in enumerate(lossfunctions):
    for roi_size in roi_sizes:
        # Data loaders and transforms
        # Training pipeline: load -> add channel axis -> rescale -> zero out dark pixels
        # -> tensor -> keep a single colour channel -> random flip -> random square crop.
        train_transform = Compose([
            LoadImaged(keys=["img", "seg"], reader='monai.data.ITKReader'),
            EnsureChannelFirstd(keys=["img", "seg"], channel_dim='no_channel'),
            ScaleIntensityd(keys=["img", 'seg']),
            ThresholdTransform(keys=["img"], threshold=0.07),
            ToTensord(keys=["img", "seg"]),
            Lambdad(keys=["img"], func=lambda x: rgb_to_grayscale(x)),
            RandFlipd(keys=["img", "seg"], prob=trans_prob, spatial_axis=1),
            RandSpatialCropSamplesd(keys=["img", "seg"], roi_size=[roi_size, roi_size], num_samples=1, random_size=False),
        ])

        # Validation/test pipeline: same preprocessing, without the random flip and crop.
        validation_test_transform = Compose([
            LoadImaged(keys=["img", "seg"], reader='monai.data.ITKReader'),
            EnsureChannelFirstd(keys=["img", "seg"], channel_dim='no_channel'),
            ScaleIntensityd(keys=["img", "seg"]),
            ThresholdTransform(keys=["img"], threshold=0.07),
            ToTensord(keys=["img", "seg"]),
            Lambdad(keys=["img"], func=lambda x: rgb_to_grayscale(x)),
        ])


        # Create CacheDataset
        train_data = CacheDataset(data=train_files, transform=train_transform)
        validation_data = CacheDataset(data=validation_files, transform=validation_test_transform)
        test_data = CacheDataset(data=test_files, transform=validation_test_transform)

        # Create DataLoader
        train_loader = DataLoader(train_data, num_workers=0, batch_size=batch_size, shuffle=True)
        validation_loader = DataLoader(validation_data, num_workers=0, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_data, num_workers=0, batch_size=batch_size, shuffle=True)

        # Model
        model = UNet(
            spatial_dims=2,           # input data is 2D
            in_channels=1,            # input image has 1 channel (if RGB it would be 3)
            out_channels=1,           # binary segmentation, each pixel is one of 2 things
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),     # strides = how much the spatial dimension is reduced
            dropout=0,
        )

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        # Full-size validation images are processed in roi_size x roi_size windows.
        inferer = monai.inferers.SlidingWindowInferer(roi_size=[roi_size, roi_size], overlap=0.5)
        # Binarises probabilities at 0.5. `threshold=` is the spelling used by the
        # evaluation code in this notebook; the legacy logit_thresh/threshold_values
        # arguments are not used here so that all cells configure AsDiscrete the same way.
        discrete_transform = monai.transforms.AsDiscrete(threshold=0.5)
        Sigmoid = torch.nn.Sigmoid()

        # Training loop
        # Loss curves, one entry per epoch. validation_loss is only filled in blocks of
        # `validation_wait` epochs, so trailing epochs beyond the last full block stay 0.
        model_loss = np.zeros(num_epochs)
        validation_loss = np.zeros(num_epochs)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.to(device)
        with tqdm(range(num_epochs), unit="epochs") as tqdm_iterator:
            tqdm_iterator.set_description('Training{}'.format(''))

            for epoch in tqdm_iterator:
                model.train()
                training_steps=0
                epoch_loss = 0
                for batch_data in train_loader:
                    training_steps+=1
                    inputs, segmentations = batch_data['img'].to(device), batch_data['seg'].to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = loss_function(outputs, segmentations)
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()
                # mean training loss over the batches of this epoch
                model_loss[epoch] = epoch_loss/training_steps
                if (epoch+1) % validation_wait == 0:
                    model.eval()
                    validation_steps=0
                    validation_epoch_loss = 0
                    for batch_data in validation_loader:
                        validation_steps+=1
                        inputs, segmentations = batch_data['img'].to(device), batch_data['seg'].to(device)

                        with torch.no_grad():
                            outputs = inferer(inputs, network=model)

                        loss = loss_function(outputs, segmentations)
                        validation_epoch_loss += loss.item()
                    # Back-fill the last `validation_wait` epochs with this validation
                    # value so both curves have one entry per epoch.
                    validation_loss[(epoch-validation_wait+1):epoch+1] = validation_epoch_loss/validation_steps*np.ones(validation_wait)

                    clear_output(wait=True)
                    display(f'Epoch {epoch+1}/{num_epochs}, Loss: {model_loss[epoch]:.4f}, Validation loss: {validation_loss[epoch]:.4f}')
                # if epoch+1 in num_epochssave:
                #     # save model
                #     torch.save(model, f"Experiment results/model_loss{lossfunctionsnames[index]}_roisize{roi_size}_epoch{epoch+1}.pth")

        # save loss
        # NOTE(review): the second file name contains two consecutive spaces ("validation  loss");
        # kept unchanged so existing result files keep their names -- confirm whether this is intended.
        np.save(f"Experiment results/model_loss{lossfunctionsnames[index]}_roisize{roi_size}_modelloss.npy", model_loss)
        np.save(f"Experiment results/model_loss{lossfunctionsnames[index]}_roisize{roi_size}_validation  loss.npy", validation_loss)

Loading dataset: 100%|██████████| 16/16 [00:00<00:00, 32.24it/s]
Loading dataset: 100%|██████████| 2/2 [00:00<00:00, 43.31it/s]
Loading dataset: 100%|██████████| 2/2 [00:00<00:00, 43.42it/s]
Training:  30%|███       | 3/10 [00:00<00:01,  5.79epochs/s]

In [None]:
# Plot the training and validation loss curves of the most recent run against the epoch number.
fig, ax1 = plt.subplots(figsize=(10, 10))
epoch_axis = np.arange(1, num_epochs + 1)  # epochs are shown 1-based
for loss_curve in (model_loss, validation_loss):
    ax1.plot(epoch_axis, loss_curve)
ax1.set(xlabel='Epochs', ylabel='Loss', title='Training loss Unet')
ax1.legend(['Training loss', 'Validation loss'])
plt.show()

In [None]:
# Load the trained model from disk. map_location places the weights on the device used
# for inference in this notebook, so a checkpoint written on a GPU machine also loads on
# a CPU-only machine.
# NOTE(review): the checkpoint is presumably a fully pickled model (see the commented-out
# torch.save(model, ...) calls); recent PyTorch releases may additionally need
# weights_only=False here -- confirm against the installed version.
model = torch.load('model_lossDiceCE_roisize128_epoch16000.pth', map_location=device)  # load the model from disk

def visual_evaluation(sample, model):
    """Run the model on one batch, plot the first sample's result and print its metrics.

    The figure has four panels: input image, ground truth, thresholded prediction,
    and the image with false positives (red) and false negatives (green) overlaid.
    Sensitivity, specificity, accuracy and precision are computed for the first
    sample of the batch only, so that prediction and ground truth refer to the
    same image; the printed loss is computed on the whole batch.

    Relies on the notebook-level names ``roi_size``, ``device`` and ``loss_function``.

    Args:
        sample: batch mapping with tensors under 'img' and 'seg', indexable as
            [0, 0, :, :].
        model: trained network passed to the sliding-window inferer.
    """
    model.eval()
    inferer = monai.inferers.SlidingWindowInferer(roi_size=[roi_size, roi_size], overlap=0.5)
    sigmoid = torch.nn.Sigmoid()
    # `threshold=` matches the AsDiscrete configuration used by compute_metric.
    discrete_transform = monai.transforms.AsDiscrete(threshold=0.5)

    with torch.no_grad():
        # Perform inference once
        output = inferer(sample['img'].to(device), network=model).cpu()
        # Apply Sigmoid to get probabilities
        probabilities = sigmoid(output)
        # Apply discrete transformation for binary segmentation
        model_segmentation = discrete_transform(probabilities)

    # Compare like with like: first sample, first channel of both tensors.
    # (Comparing the full prediction batch against one ground truth would broadcast
    # and count pixels of other samples against the wrong reference.)
    ground_truth_segmentation = sample["seg"][0, 0, :, :].squeeze()
    predicted_segmentation = model_segmentation[0, 0, :, :].squeeze()

    predicted_positive = predicted_segmentation > 0.9
    predicted_negative = predicted_segmentation < 0.1
    truth_positive = ground_truth_segmentation == 1
    truth_negative = ground_truth_segmentation == 0

    TP = torch.sum(predicted_positive & truth_positive).float()
    FN = torch.sum(predicted_negative & truth_positive).float()
    TN = torch.sum(predicted_negative & truth_negative).float()
    FP = torch.sum(predicted_positive & truth_negative).float()

    # The small constant avoids a division by zero when a class is absent.
    sensitivity = TP / (TP + FN + 1e-6)
    specificity = TN / (TN + FP + 1e-6)
    accuracy = (TP + TN) / (TP + TN + FP + FN + 1e-6)
    precision = TP / (TP + FP + 1e-6)

    fig, ax = plt.subplots(2, 2, figsize=[10, 10])

    # Image
    img_np = sample["img"][0, 0, :, :].squeeze().numpy()
    ax[0, 0].imshow(img_np, cmap='gray')
    ax[0, 0].set_title('Image')

    # Ground truth
    ground_truth_segmentation_np = ground_truth_segmentation.numpy()
    ax[0, 1].imshow(ground_truth_segmentation_np, cmap='gray')
    ax[0, 1].set_title('Ground Truth Segmentation')

    # Prediction
    model_segmentation_np = predicted_segmentation.numpy()
    ax[1, 0].imshow(model_segmentation_np, cmap='gray')
    ax[1, 0].set_title('Prediction')

    # Image with false positives and false negatives on top
    ax[1, 1].imshow(img_np, cmap='gray')

    # False positive overlay
    false_positives = (model_segmentation_np > 0.9) & (ground_truth_segmentation_np == 0)
    false_positive_overlay = np.zeros_like(model_segmentation_np)
    false_positive_overlay[false_positives] = 1

    # False negative overlay
    false_negatives = (model_segmentation_np < 0.1) & (ground_truth_segmentation_np == 1)
    false_negative_overlay = np.zeros_like(model_segmentation_np)
    false_negative_overlay[false_negatives] = 1

    # RGB overlay: red channel = false positives, green channel = false negatives
    overlay = np.zeros((img_np.shape[0], img_np.shape[1], 3), dtype=np.float32)
    overlay[..., 0] = false_positive_overlay
    overlay[..., 1] = false_negative_overlay

    ax[1, 1].imshow(overlay, alpha=0.7)
    ax[1, 1].set_title('Prediction with False Positives and False Negatives')

    plt.tight_layout()
    plt.show()

    print("Sensitivity is {:.4f}".format(sensitivity.item()))
    print("Specificity is {:.4f}".format(specificity.item()))
    print("Accuracy is {:.4f}".format(accuracy.item()))
    print("Precision is {:.4f}".format(precision.item()))

    # Loss of the raw (pre-sigmoid) network output against the ground truth, whole batch.
    print("Loss is {}".format(loss_function(output, sample['seg'])))

def compute_metric(dataloader, model, metric_fn):
    """
    This function computes the average value of a metric for a data set.

    Every batch is segmented with a sliding-window inferer, passed through a
    sigmoid and binarised at 0.5 before the metric is evaluated. The metric
    values of all batches are pooled (running sum and count), so batches of
    different sizes -- e.g. a smaller last batch -- are averaged correctly.

    Relies on the notebook-level names ``roi_size`` and ``device``.

    Args:
        dataloader (monai.data.DataLoader): dataloader wrapping the dataset to evaluate.
        model (torch.nn.Module): trained model to evaluate.
        metric_fn (function): function computing the metric value from two tensors:
            - a batch of outputs,
            - the corresponding batch of ground truth masks.

    Returns:
        (float) the mean value of the metric

    Raises:
        ValueError: if the dataloader yields no metric values at all.
    """
    model.eval()
    inferer = monai.inferers.SlidingWindowInferer(roi_size=[roi_size, roi_size], overlap=0.5)
    discrete_transform = monai.transforms.AsDiscrete(threshold=0.5)
    Sigmoid = torch.nn.Sigmoid()

    value_sum = 0.0      # sum of all metric values seen so far
    value_count = 0      # how many metric values went into value_sum

    for sample in dataloader:
        with torch.no_grad():
            output = discrete_transform(Sigmoid(inferer(sample['img'].to(device), network=model).cpu()))
        # as_tensor also accepts plain Python numbers returned by simple metric functions
        batch_values = torch.as_tensor(metric_fn(output, sample["seg"]), dtype=torch.float32)
        value_sum += batch_values.sum().item()
        value_count += batch_values.numel()

    if value_count == 0:
        raise ValueError("compute_metric received no data: the dataloader produced no metric values.")
    return value_sum / value_count

# Requires test_data, model, roi_size, device and loss_function from the cells above
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=0)

# Show two evaluations; every call draws the first batch of a freshly started (shuffled) pass.
for _repeat in range(2):
    visual_evaluation(next(iter(test_loader)), model)


dice_metric = monai.metrics.DiceMetric(include_background=True, reduction="mean")

mean_dice = compute_metric(test_loader, model, dice_metric)  # Dice score on the test set
print('Mean Dice score is {}'.format(mean_dice))


In [None]:
# save model
# torch.save(model, f'model_epoch{num_epochs}_loss_Dice.pth')	# save the model to disk

In [None]:
# Compute the test-set Dice score of every saved model checkpoint
# get all model names
model_names = glob.glob('Experiment results/*.pth')
# sort model names by file modification time (oldest first)
model_names.sort(key=os.path.getmtime)
print(model_names)
# create array of model names
model_names = np.array(model_names)
# reshape model_names to 8x5 (8 rows of 5 consecutive files; requires exactly 40 checkpoints)
# presumably one row per (loss function, roi size) run and one column per saved epoch -- TODO confirm
model_names = model_names.reshape((8,5))
# print(model_names)
# Reorder the rows: the last row moves to position 1 and rows 1..6 shift down by one.
# NOTE(review): this depends on the order in which the checkpoints were written -- verify
# the printed table against the file names before trusting the row labels.
model_names_sorted = model_names.copy()
model_names_sorted[2:,:]=model_names_sorted[1:-1,:]
model_names_sorted[1,:]=model_names[-1,:]
print(model_names_sorted)
# Sliding-window size per row of model_names_sorted. Kept under its own name so the
# `roi_sizes` experiment setting used by the training cell is not overwritten.
model_roi_sizes = [32, 64, 128, 256, 32, 64, 128, 256]
# calculate dice score for all models on test set
# (the array keeps its historical name `dice_losses`, but holds DiceMetric scores)
dice_losses = np.zeros((8,5))
discrete_transform = monai.transforms.AsDiscrete(threshold=0.5)
Sigmoid = torch.nn.Sigmoid()
for ii1, model_name_list in enumerate(model_names_sorted):
    # all checkpoints in one row share the same window size
    inferer = monai.inferers.SlidingWindowInferer(roi_size=[model_roi_sizes[ii1], model_roi_sizes[ii1]], overlap=0.5)
    for ii2,model_name in enumerate(model_name_list):
        # map_location lets GPU-written checkpoints load on the device used here
        model = torch.load(model_name, map_location=device)
        model.eval()

        # Pool the metric values of all batches so unequal batch sizes are averaged correctly.
        score_sum = 0.0
        score_count = 0

        for sample in test_loader:
            with torch.no_grad():
                output = discrete_transform(Sigmoid(inferer(sample['img'].to(device), network=model).cpu()))
            batch_scores = dice_metric(output, sample["seg"])
            score_sum += batch_scores.sum().item()
            score_count += batch_scores.numel()
        dice_losses[ii1,ii2] = score_sum / score_count

# save dice scores
np.save('Experiment results/dice_losses.npy', dice_losses)

In [None]:
# Plot the Dice score against the number of training epochs, one line per row of dice_losses
fig, ax = plt.subplots()
colors = ['b', 'g', 'r', 'c']
print(dice_losses)

for i in range(len(model_names)):
    # Rows 0-3 are drawn with the default line style, rows 4 and up dashed;
    # both groups run through the same colour list.
    in_dashed_group = i > 3
    line_options = {'linestyle': 'dashed'} if in_dashed_group else {}
    color_index = i - 4 if in_dashed_group else i
    ax.plot(num_epochssave, dice_losses[i], label=f"{model_names[i]}",
            color=colors[color_index], linewidth=2, **line_options)
ax.set_xlabel('Epochs')
ax.set_ylabel('Dice score')
# The explicit legend entries below replace the per-line labels set above.
ax.legend(['Dice_32','Dice_64','Dice_128','Dice_256','DiceCE_32','DiceCE_64','DiceCE_128','DiceCE_256'])
plt.show()

In [None]:
def test_secret(sample, model, samplenr, roi_size):
    """Segment one unlabeled sample, show the result and write it to disk.

    The network output is passed through a sigmoid and binarised at 0.5, plotted
    next to the input image, and written to ``testresults/<samplenr>_secret.nii.gz``.
    The output folder is created when it does not exist yet.

    Relies on the notebook-level name ``device``.

    Args:
        sample: batch mapping with an 'img' tensor indexable as [0, 0, :, :].
        model: trained network passed to the sliding-window inferer.
        samplenr: number used as prefix of the output file name.
        roi_size: edge length of the square sliding window.
    """
    model.eval()
    inferer = monai.inferers.SlidingWindowInferer(roi_size=[roi_size, roi_size], overlap=0.5)
    # `threshold=` matches the AsDiscrete configuration used by the evaluation code,
    # so the saved array is a 0/1 mask rather than raw probabilities.
    discrete_transform = monai.transforms.AsDiscrete(threshold=0.5)
    Sigmoid = torch.nn.Sigmoid()
    with torch.no_grad():
        output = discrete_transform(Sigmoid(inferer(sample['img'].to(device), network=model).cpu()))

    fig, ax = plt.subplots(1, 2)

    ax[0].imshow(sample["img"][0,0,:,:].squeeze(), 'gray')
    ax[0].set_title('Image')

    # Mask out the zero pixels so only the predicted foreground is drawn.
    overlay_output = np.ma.masked_where(output[0, 0] == 0, output[0, 0])
    ax[1].imshow(overlay_output[:,:], 'Reds', alpha = 1, clim=[0,1])
    ax[1].set_title('Prediction')
    plt.show()

    # save the output
    # os.path.join instead of a hard-coded backslash keeps the path valid on every OS.
    output_dir = 'testresults'
    os.makedirs(output_dir, exist_ok=True)
    savepath = os.path.join(output_dir, f'{samplenr}_secret.nii.gz')
    writer = ITKWriter(output_dtype=output.dtype)

    # Save the MetaTensor to a file
    writer.set_data_array(output)
    writer.write(savepath)

# Preprocessing for the unlabeled secret test images. It lists the same steps as the
# validation/test pipeline of the training cell, but only for the "img" key because
# these samples carry no "seg" entry.
secret_test_transform = Compose([
            LoadImaged(keys=["img"], reader='monai.data.ITKReader'),
            EnsureChannelFirstd(keys=["img"], channel_dim='no_channel'),
            ScaleIntensityd(keys=["img"]),
            # zero out every value below 0.07 (see ThresholdTransform)
            ThresholdTransform(keys=["img"], threshold=0.07),
            ToTensord(keys=["img"]),
            # keep only index 1 of the last axis (see rgb_to_grayscale)
            Lambdad(keys=["img"], func=lambda x: rgb_to_grayscale(x)),
])

In [None]:
# secret test set:
secret_test = sorted(glob.glob(os.path.join(data_path_secret_test, '*.tif')))
print('secret test = ', secret_test)

# One dictionary per image path (plain strings; no need to wrap each path in a tuple via zip).
secret_test_files = [{"img": img} for img in secret_test]
test_secret_data = CacheDataset(data=secret_test_files, transform=secret_test_transform)
# batch_size=1 and no shuffling, so the running number below follows the sorted file order
test_secret_loader = DataLoader(test_secret_data, num_workers=0, batch_size=1, shuffle=False)

# load best model
# map_location lets a GPU-written checkpoint load on the device used for inference here.
model = torch.load('Experiment results/model_lossDiceCE_roisize128_epoch16000.pth', map_location=device) # check which is best
roi_size = 128  # sliding-window size; taken from the checkpoint name (roisize128)
for number, batch in enumerate(test_secret_loader):
    print(batch["img"].shape)
    # Output files are numbered from 41 upwards -- presumably the ids of the secret
    # test images; confirm against the file names printed above.
    test_secret(batch, model, samplenr = number+41, roi_size=roi_size)