In [None]:
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [None]:
def dice_loss(prediction, target, smooth=1e-6):
    """
    Return ``1 - Dice coefficient`` for two tensors.

    The Dice coefficient here is ``(2 * sum(prediction * target) + smooth) /
    (sum(prediction) + sum(target) + smooth)``, computed over every element of
    both tensors (no per-sample or per-channel reduction). It is commonly used
    as a loss for binary segmentation.

    Parameters:
        prediction (torch.Tensor): The predicted tensor.
        target (torch.Tensor): The target tensor; must broadcast with ``prediction``.
        smooth (float, optional): Added to numerator and denominator so that two
            all-zero inputs give a loss of 0 instead of dividing by zero.
            Default is 1e-6.

    Returns:
        torch.Tensor: A scalar tensor holding the Dice Loss.
    """
    overlap = torch.sum(torch.mul(prediction, target))
    total_mass = torch.sum(prediction) + torch.sum(target)
    return 1. - (2. * overlap + smooth) / (total_mass + smooth)

def fit():
    """
    Placeholder that currently does nothing.

    Presumably intended to hold the training loop (TODO: implement).

    Returns:
        None
    """
    return None

def predict_mask(model, image, device, transform=None):
    """
    Predicts mask using the provided model and image.

    If ``transform`` is given, it is applied to the image before inference. The
    transformed tensor is moved to ``device`` and passed to the model. The
    caller's ``image`` is returned unchanged: it is not transformed and not
    moved. That makes it convenient to plot the original next to the prediction.

    NOTE: this function neither switches the model to eval mode nor disables
    autograd. Call ``model.eval()`` and/or wrap the call in ``torch.no_grad()``
    at the call site if that is wanted.

    Parameters:
        model (torch.nn.Module): The trained model for prediction.
        image (torch.Tensor): The input image tensor.
        device (str or torch.device): The device the model input is moved to.
            The model itself is expected to already live on this device.
        transform (callable, optional): A function/transform to preprocess the
            input image before inference. Default is None.

    Returns:
        tuple: ``(image, y)``: the input image exactly as passed in, and the
        model output (the predicted mask tensor).
    """
    # Apply the optional preprocessing. Previously its result was overwritten
    # by ``image.to(device)`` and silently discarded.
    # ``is not None`` rather than truthiness: e.g. an empty nn.Sequential is falsy.
    model_input = transform(image) if transform is not None else image
    model_input = model_input.to(device)

    # Call the module instead of ``.forward`` so registered hooks are run.
    y = model(model_input)
    return image, y

def convert_PIL(images):
    """
    Convert each PyTorch tensor in ``images`` to a PIL image.

    A single ``transforms.ToPILImage()`` converter (default mode inference) is
    created and applied to every element, in order.

    Parameters:
        images (iterable of torch.Tensor): Image tensors to convert.

    Returns:
        list of PIL.Image.Image: The converted images, in the same order as the input.
    """
    tensor_to_pil = transforms.ToPILImage()
    return [tensor_to_pil(tensor) for tensor in images]

def predict_visulise(truth, prediction_mask, original):
    """
    Visualizes ground truth, original images, predicted masks, and overlaid masks.

    Draws one row per sample with four columns:
    ground truth (gray), original image, predicted mask (gray), and the
    original with the predicted mask drawn on top at 50% opacity.
    The figure is shown with ``plt.show()``.

    Parameters:
        truth (list of torch.Tensor): A list containing ground truth image tensors.
        prediction_mask (list of torch.Tensor): A list containing predicted mask tensors.
        original (list of torch.Tensor): A list containing original image tensors.

    Raises:
        ValueError: If the lengths of the input arrays are not the same, or if
            they are empty.

    Returns:
        None
    """
    # Convert tensors back into PIL images.
    truth = convert_PIL(truth)
    prediction_mask = convert_PIL(prediction_mask)
    original = convert_PIL(original)

    plot_number = len(truth)

    # Raise an error when the array lengths are not the same.
    if plot_number != len(prediction_mask) or plot_number != len(original):
        raise ValueError(f'Expected same length for input arrays, but got {len(truth)}, {len(prediction_mask)}, {len(original)}')
    # plt.subplots cannot create a grid with zero rows; fail with a clear message.
    if plot_number == 0:
        raise ValueError('Expected at least one image to plot, but got empty input arrays')

    # squeeze=False keeps ``ax`` 2-D even when plot_number == 1. Without it a
    # single row yields a 1-D array, and ax[i][j] indexing would fail.
    _, ax = plt.subplots(plot_number, 4, figsize=(6, plot_number * 1.8), squeeze=False)

    for row, gt_img, orig_img, pred_img in zip(ax, truth, original, prediction_mask):
        row[0].imshow(gt_img, cmap='gray')
        row[1].imshow(orig_img)
        row[2].imshow(pred_img, cmap='gray')
        # Overlay: original first, then the predicted mask semi-transparent on top.
        row[3].imshow(orig_img)
        row[3].imshow(pred_img, alpha=.5, cmap='gray')

        for axis in row:
            axis.axis('off')

    # Titles only on the top row.
    for axis, title in zip(ax[0], ('Ground Truth', 'Original', 'Predicted', 'Mask')):
        axis.set_title(title)

    # Display the figure.
    plt.tight_layout()
    plt.show()

