In [None]:
import os
import shutil
from os.path import join

def create_folder_structure_for_segmentation(images_folder, labels_folder, dir_path, num_train_images = 76, num_val_images = 25):
    """Copy an image/label dataset into a train/val/test folder layout.

    The resulting layout is::

        dir_path/{train,val,test}/{images,segmentations}/

    Files are taken in sorted filename order: the first ``num_train_images``
    go to ``train``, the next ``num_val_images`` to ``val`` and everything that
    is left to ``test``. Images and labels are listed and sorted independently,
    so an image and its label end up in the same split only if both folders
    sort into the same order.

    Existing target folders are reused (``exist_ok=True``), so the function can
    be re-run; files with the same name are overwritten by ``shutil.copy``.

    Args:
        images_folder: folder containing the image files.
        labels_folder: folder containing the segmentation label files.
        dir_path: root folder of the split that is created.
        num_train_images: number of files copied into the train split.
        num_val_images: number of files copied into the validation split.
    """
    # Sorted listings of both folders; position in the list decides the split.
    sources = (
        ("images", images_folder, sorted(os.listdir(images_folder))),
        ("segmentations", labels_folder, sorted(os.listdir(labels_folder))),
    )

    val_end = num_train_images + num_val_images
    splits = (
        ("train", slice(None, num_train_images)),
        ("val", slice(num_train_images, val_end)),
        ("test", slice(val_end, None)),
    )

    # Create the complete folder tree first so a failing copy never leaves a
    # split without its folders.
    for split_name, _ in splits:
        for kind, _, _ in sources:
            os.makedirs(os.path.join(dir_path, split_name, kind), exist_ok=True)

    # Copy every file into the folder of the split its position belongs to.
    for split_name, split_slice in splits:
        for kind, source_folder, file_names in sources:
            target_folder = os.path.join(dir_path, split_name, kind)
            for file_name in file_names[split_slice]:
                shutil.copy(
                    os.path.join(source_folder, file_name),
                    os.path.join(target_folder, file_name),
                )


In [None]:
# Build the train/val/test split of the cropped, fully annotated data
# (default split sizes of the function are used).
create_folder_structure_for_segmentation(
    images_folder=join("data", "cropped_segmentation_data", "data", "trans"),
    labels_folder=join("data", "cropped_segmentation_data", "references", "trans"),
    dir_path=join("segmentation", "data", "cropped_fully_annotated"),
)

In [None]:
import torch
import torchvision.models as models

# Instantiate torchvision's ResNet-18 (constructor called without arguments).
resnet18 = models.resnet18()

# Print every direct child module of the network next to its attribute name.
for name, module in resnet18.named_children():
    print(f"{name}: {module}")

In [None]:
import torch
import torchvision.models as models

# Create a ResNet-34 model.
# NOTE(review): the variable is still named `resnet18` although it holds a
# ResNet-34, and it rebinds the `resnet18` name assigned in the previous cell.
resnet18 = models.resnet34()

# Print every direct child module of the ResNet-34 model next to its name.
for name, module in resnet18.named_children():
    print(f"{name}: {module}")

In [None]:
from PIL import Image
# NOTE(review): absolute, machine-specific path — the cell only runs on the
# author's computer. Presumably this should be relative to the project's
# `data` folder; confirm the intended base directory before changing it.
# `path` is read again by the following cells.
path = "C://Users//ondra//school//Diplomka//diploma-thesis//data//new_segmentation_data//data//trans//p_201501061017320170VAS.png"
# Open the sample image; `img` is not referenced again in the visible cells.
img = Image.open(path)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

# Load the sample image and derive its single-channel ('L') version.
rgb_image = Image.open(path)
grayscale_image = rgb_image.convert('L')

# One entry per subplot: title, image, extra keyword arguments for imshow.
panels = (
    ('RGB Image', rgb_image, {}),
    ('Grayscale Image', grayscale_image, {'cmap': 'gray'}),
)

# Draw both versions side by side without axes.
plt.figure(figsize=(10, 5))
for position, (panel_title, panel_image, imshow_options) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.title(panel_title)
    plt.imshow(panel_image, **imshow_options)
    plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
from PIL import Image
import numpy as np

# Load the image and derive its grayscale ('L') version
image_path = path
rgb_image = Image.open(image_path)
grayscale_image = rgb_image.convert('L')

# Get the size of the image
width, height = rgb_image.size

# 3x3 window around the central pixel: indices center-1 .. center+1.
# The *_end values are EXCLUSIVE slice bounds (center + 2).
x_center = width // 2
y_center = height // 2
x_start = x_center - 1
x_end = x_center + 2
y_start = y_center - 1
y_end = y_center + 2

# Convert both images to arrays indexed as [row (y), column (x), ...]
rgb_pixels = np.array(rgb_image)
grayscale_pixels = np.array(grayscale_image)

# The end bounds are already exclusive; the former `+ 1` produced a 4x4 patch
# although a 3x3 region is wanted (and printed below).
rgb_patch = rgb_pixels[y_start:y_end, x_start:x_end, :]
grayscale_patch = grayscale_pixels[y_start:y_end, x_start:x_end]

# Print the RGB pixel values with their (x, y) image coordinates
print("RGB Pixel Values:")
for i in range(rgb_patch.shape[0]):
    for j in range(rgb_patch.shape[1]):
        print(f"RGB ({x_start + j}, {y_start + i}): {rgb_patch[i, j]}")

# Print the grayscale pixel values with their (x, y) image coordinates
print("\nGrayscale Pixel Values:")
for i in range(grayscale_patch.shape[0]):
    for j in range(grayscale_patch.shape[1]):
        print(f"Grayscale ({x_start + j}, {y_start + i}): {grayscale_patch[i, j]}")


In [None]:
import torch
import torch.nn.functional as F

def w_nll_loss(prediction, label, n_classes = False):
    """Class-balanced negative log-likelihood for already-softmaxed predictions.

    Class ``i`` gets the weight ``1 - (number of label pixels equal to i) /
    (total number of label pixels)``, computed from ``label`` on every call.
    The weighted mean is the one computed by ``F.nll_loss`` (normalised by the
    sum of the weights of the target pixels).

    :param prediction: Class probabilities with the class dimension at dim 1
                       (presumably [batch_size, num_classes, height, width]).
    :param label: Integer class index per pixel (presumably
                  [batch_size, height, width]).
    :param n_classes: Number of classes; any falsy value (the default) means
                      it is read from ``prediction.shape[1]``.
    :return: Scalar loss tensor. NaN if every pixel has the same class, because
             that class then has weight 0 and the weight sum is 0.
    """
    if not n_classes:
        n_classes = prediction.shape[1]

    total_pixels = label.numel()
    # Build the weights in the dtype/device of the prediction so that
    # F.nll_loss also accepts non-float32 inputs.
    weights = torch.tensor(
        [1 - (torch.sum(label == i).item() / total_pixels) for i in range(n_classes)],
        dtype=prediction.dtype,
        device=prediction.device,
    )

    # An exactly-zero probability would give log(0) = -inf and an infinite
    # loss; clamping to the smallest normal float keeps every non-zero value
    # untouched while making the result finite.
    log_prediction = torch.log(prediction.clamp_min(torch.finfo(prediction.dtype).tiny))
    return F.nll_loss(log_prediction, label, weight=weights)

def weighted_cross_entropy_loss(outputs, targets, num_classes = 4):
    """
    Class-weighted cross-entropy loss for already softmaxed outputs.

    Every pixel contributes, including pixels of class 0. Class ``c`` gets the
    weight ``1 - (pixels of class c) / (all pixels)``, computed from
    ``targets``. The weighted negative log-probabilities are summed and divided
    by the total number of pixels (not by the sum of the weights, so the value
    differs from ``F.nll_loss(..., weight=...)``).

    :param outputs: Softmax probabilities of each class at each pixel,
                    shape [batch_size, num_classes, height, width]
    :param targets: Ground truth labels, shape [batch_size, height, width]
    :param num_classes: Number of classes; must be larger than every label
    :return: Loss value (scalar tensor)
    """
    # Keep probabilities away from 0 and 1 so the logarithm stays finite
    outputs = torch.clamp(outputs, min=1e-7, max=1 - 1e-7)

    # Per-class weight: the rarer the class in `targets`, the larger the weight
    total_pixels = float(targets.numel())
    class_weights = torch.tensor(
        [1 - float((targets == c).sum()) / total_pixels for c in range(num_classes)]
    ).to(outputs.device)

    # Probability assigned to the correct class at every pixel
    gathered_probs = outputs.gather(1, targets.unsqueeze(1)).squeeze(1)

    # Weight each pixel's log-probability by the weight of its target class
    weighted_log_probs = torch.log(gathered_probs) * class_weights[targets]

    return -weighted_log_probs.sum() / total_pixels

def weighted_partial_cross_entropy_loss(outputs, targets, num_classes = 4):
    """
    Class-weighted cross-entropy loss for already softmaxed outputs that is
    computed only on annotated pixels (targets different from zero).

    Class 0 has weight 0. Every other class ``c`` gets the weight
    ``1 - (pixels of class c) / (annotated pixels)``. The weighted negative
    log-probabilities of the annotated pixels are summed and divided by the
    number of annotated pixels.

    :param outputs: Softmax probabilities of each class at each pixel,
                    shape [batch_size, num_classes, height, width]
    :param targets: Ground truth labels, shape [batch_size, height, width];
                    label 0 marks a pixel without annotation
    :param num_classes: Number of classes; must be larger than every label
    :return: Loss value (scalar tensor); zero when no pixel is annotated
    """
    # Keep probabilities away from 0 and 1 so the logarithm stays finite
    outputs = torch.clamp(outputs, min=1e-7, max=1 - 1e-7)

    # 1.0 for annotated pixels, 0.0 for unannotated ones
    mask = (targets != 0).float()
    total_pixels = float(mask.sum())  # number of annotated pixels

    if total_pixels == 0:
        # Nothing to learn from; the weight computation below would divide by
        # zero. Return a zero that is still connected to `outputs`.
        return outputs.sum() * 0.0

    # Weight 0 for class 0, inverse-frequency style weights for the others
    class_weights = [0.0]
    for c in range(1, num_classes):
        class_pixels = float((targets == c).sum())
        class_weights.append(1 - (class_pixels / total_pixels))
    class_weights = torch.tensor(class_weights).to(outputs.device)

    # Probability assigned to the correct class at every pixel
    gathered_probs = outputs.gather(1, targets.unsqueeze(1)).squeeze(1)

    # Weight each pixel's log-probability by the weight of its target class
    weighted_log_probs = torch.log(gathered_probs) * class_weights[targets]

    # Drop unannotated pixels and average over the annotated ones
    return -(weighted_log_probs * mask).sum() / mask.sum()

def custom_cross_entropy_loss_without_softmax(outputs, targets):
    """
    Unweighted cross-entropy on softmax probabilities, averaged over the
    pixels whose target label is not zero.

    :param outputs: Softmax probabilities of each class at each pixel,
                    shape [batch_size, num_classes, height, width]
    :param targets: Ground truth labels, shape [batch_size, height, width]
    :return: Loss value
    """
    # Bound the probabilities so that the logarithm below is always finite
    probabilities = outputs.clamp(min=1e-7, max=1 - 1e-7)

    # 1.0 where the pixel carries a non-zero label, 0.0 elsewhere
    labelled = targets.ne(0).float()

    # Probability predicted for the true class of every pixel
    true_class_probs = torch.gather(probabilities, 1, targets[:, None]).squeeze(1)

    # Negative log-likelihood, zeroed on unlabelled pixels
    masked_nll = torch.log(true_class_probs).mul(labelled).neg()

    # Mean over the labelled pixels only
    return masked_nll.sum() / labelled.sum()

def custom_nll_loss(outputs, targets):
    """
    Negative log-likelihood averaged over the entries whose target is not zero.

    The value stored in ``outputs`` at the target class is negated and
    averaged directly, so ``outputs`` presumably holds log-probabilities
    (e.g. from ``log_softmax``) — TODO confirm against the caller.

    :param outputs: Scores with the class dimension at dim 1, e.g.
                    [batch_size, num_classes] or
                    [batch_size, num_classes, height, width]
    :param targets: Integer class indices, either without the class dimension
                    ([batch_size] / [batch_size, height, width]) or with a
                    singleton class dimension ([batch_size, 1, ...])
    :return: Loss value (scalar tensor); zero when no target is non-zero
    """
    mask = targets != 0

    # gather needs an index with the same number of dims as `outputs`;
    # add the class dimension unless the caller already supplied it.
    index = targets if targets.dim() == outputs.dim() else targets.unsqueeze(1)
    picked = outputs.gather(1, index).reshape(targets.shape)

    # Avoid 0/0 when nothing is annotated: the numerator is 0 in that case.
    labelled = mask.sum().clamp(min=1)
    return -(picked * mask).sum() / labelled

# Mock data for testing
# Batch size of 1, 4 classes, and a 2x2 image
batch_size, num_classes, height, width = 1, 4, 2, 2

# Create random logits and apply softmax over the class dimension to simulate
# model outputs (probabilities, not log-probabilities).
logits = torch.randn(batch_size, num_classes, height, width)
outputs = F.softmax(logits, dim=1)  # explicit dim: the implicit choice is deprecated
# Create target labels (ground truth); every pixel here carries a non-zero label
targets = torch.tensor([[[1, 1], [2, 1]]])

# Print inputs
print("Model Outputs (Probabilities):")
print(outputs)
print("\nTarget Labels:")
print(targets)

# # Compute and print the custom loss
# loss = custom_nll_loss(outputs, targets)
# print("\nComputed Loss:")
# print(loss)


# loss = F.cross_entropy(logits, targets)
# print("\nComputed Loss:")
# print(loss)

# loss = weighted_partial_cross_entropy_loss(outputs, targets)
# print("\nComputed Loss:")
# print(loss)

# Weighted cross entropy normalised by the pixel count
loss = weighted_cross_entropy_loss(outputs, targets)
print("\nComputed Loss:")
print(loss)

# Weighted NLL as computed by F.nll_loss
loss = w_nll_loss(outputs, targets)
print("\nComputed Loss:")
print(loss)

In [None]:
import torch
import torch.nn.functional as F

import torch

def weighted_cross_entropy_loss(outputs, targets, class_weights):
    """
    Weighted cross entropy loss for already softmaxed outputs.

    The per-pixel loss ``-log(p_true + 1e-6) * weight[true class]`` is averaged
    over ALL pixels (plain mean). This differs from ``F.nll_loss`` /
    ``F.cross_entropy`` with ``weight=``, which divide by the sum of the
    weights of the target pixels.

    :param outputs: Softmax probabilities from the neural network,
                    shape [batch_size, num_classes, height, width].
    :param targets: Ground truth labels, shape [batch_size, height, width].
    :param class_weights: Class weights of shape [num_classes]; a tensor or
                          anything ``torch.as_tensor`` accepts (e.g. a list).
    :return: Weighted cross entropy loss.
    """
    # as_tensor accepts tensors and sequences alike and, unlike
    # torch.tensor(existing_tensor), does not emit a copy-construct warning.
    class_weights = torch.as_tensor(class_weights, dtype=outputs.dtype, device=outputs.device)

    # Reshape targets to use for gather
    targets = targets.unsqueeze(1)  # shape becomes [batch_size, 1, height, width]

    # Gather the probabilities of the true classes
    true_class_probs = torch.gather(outputs, 1, targets).squeeze(1)

    # Weight of the true class at every pixel
    weights = class_weights[targets].squeeze(1)  # shape becomes [batch_size, height, width]

    # Compute the loss
    loss = -torch.log(true_class_probs + 1e-6) * weights  # Adding epsilon for numerical stability
    return loss.mean()

# # Example usage
# batch_size, num_classes, height, width = 1, 3, 4, 4
# outputs = torch.softmax(torch.randn(batch_size, num_classes, height, width), dim=1)  # Example softmaxed outputs
# targets = torch.randint(0, num_classes, (batch_size, height, width))  # Example targets
# class_weights = torch.tensor([1.0, 2.0, 0.5])  # Example class weights

# loss = weighted_cross_entropy_loss(outputs, targets, class_weights)
# print(loss)


def weighted_cross_entropy_for_softmaxed_output(outputs, targets, class_weights):
    """
    Weighted cross entropy on probabilities, delegated to ``F.nll_loss``.

    :param outputs: Softmax probabilities from the neural network,
                    shape [batch_size, num_classes, height, width].
    :param targets: Ground truth labels, shape [batch_size, height, width].
    :param class_weights: Tensor of shape [num_classes] with class weights.
    :return: Weighted cross entropy loss.
    """
    # A small offset keeps the logarithm finite for zero probabilities
    stabilised = outputs + 1e-6

    # Weighted mean negative log-likelihood of the true classes
    return F.nll_loss(stabilised.log(), targets, weight=class_weights, reduction='mean')

# Example usage: evaluate the three weighted cross-entropy variants on the
# same random batch so their values can be compared.
batch_size, num_classes, height, width = 1, 3, 2, 2
o = torch.randn(batch_size, num_classes, height, width)  # raw logits
outputs = F.softmax(o, dim=1)  # Example softmaxed outputs
targets = torch.randint(
    low=0, high=num_classes, size=(batch_size, height, width)
)  # Example targets

class_weights = torch.tensor([2.0, 1.0, 1.0])  # Example class weights

# 1) log + F.nll_loss on the probabilities
loss = weighted_cross_entropy_for_softmaxed_output(outputs, targets, class_weights)
print(loss)

# 2) hand-written gather-based version
loss = weighted_cross_entropy_loss(outputs, targets, class_weights)
print(loss)

# 3) PyTorch reference working on the logits
loss = F.cross_entropy(o, targets, weight=class_weights)
print(loss)

In [None]:
import torch



def wce_already_sofmaxed(outputs, targets):
    """
    Weighted cross entropy loss for already softmaxed outputs.

    The class weights are derived from ``targets`` on every call: class ``c``
    gets ``1 - (pixels of class c) / (all pixels)``. The number of classes is
    taken from ``outputs.shape[1]``.

    :param outputs: Softmax probabilities from the neural network,
                    shape [batch_size, num_classes, height, width].
    :param targets: Ground truth labels, shape [batch_size, height, width].
    :return: Weighted cross entropy loss (weighted mean of ``nll_loss``).
    """
    # Read the class count from the input instead of a notebook-level
    # `num_classes` variable, which may be undefined or stale.
    n_classes = outputs.shape[1]

    total_pixels = float(targets.numel())  # all pixels, class 0 included
    class_weights = torch.tensor(
        [1 - float((targets == c).sum()) / total_pixels for c in range(n_classes)],
        dtype=outputs.dtype,
        device=outputs.device,
    )

    # Convert softmax outputs to log probabilities
    log_probs = torch.log(outputs + 1e-6)  # Adding epsilon for numerical stability
    return torch.nn.functional.nll_loss(log_probs, targets, weight=class_weights, reduction='mean')

# def weighted_cross_entropy_loss(outputs, targets, class_weights):
#     """
#     Weighted cross entropy loss for already softmaxed outputs, 
#     computed only for annotated pixels.

#     :param outputs: Softmax probabilities from the neural network,
#                     shape [batch_size, num_classes, height, width].
#     :param targets: Ground truth labels, shape [batch_size, height, width].
#                     Annotated pixels are non-zero.
#     :param class_weights: Tensor of shape [num_classes] with class weights.
#     :return: Weighted cross entropy loss.
#     """
#     # Ensure class_weights is a tensor and move to the same device as outputs
#     class_weights = torch.tensor(class_weights, dtype=outputs.dtype).to(outputs.device)

#     # Create a mask for annotated pixels (non-zero in targets)
#     mask = targets != 0

#     # Reshape targets to use for gather, and apply the mask
#     targets_masked = targets.unsqueeze(1) * mask.unsqueeze(1)

#     # Gather the probabilities of the true classes for annotated pixels
#     true_class_probs = torch.gather(outputs, 1, targets_masked).squeeze(1)

#     # Apply class weights
#     weights = class_weights[targets] * mask

#     # Compute the loss, ignore unannotated pixels by masking
#     loss = -torch.log(true_class_probs + 1e-6) * weights  # Adding epsilon for numerical stability
#     return loss[mask].mean()  # Average over only the annotated pixels

def pwce_already_sofmaxed(outputs, targets):
    """
    Partial Weighted NLL loss for softmaxed outputs, computed only on annotated pixels.

    Class 0 (no annotation) gets weight 0, so those pixels do not contribute.
    Every other class ``c`` gets ``1 - (pixels of class c) / (annotated
    pixels)``. The number of classes is taken from ``outputs.shape[1]``.

    :param outputs: Softmax probabilities from the neural network,
                    shape [batch_size, num_classes, height, width].
    :param targets: Ground truth labels, shape [batch_size, height, width].
                    Annotated pixels are non-zero.
    :return: Weighted NLL loss; zero when no pixel is annotated. NaN if all
             annotated pixels share one class (its weight and therefore the
             weight sum is 0).
    """
    # Annotated pixels are the non-zero ones
    mask = targets != 0
    total_pixels = float(mask.sum())  # Total number of labeled pixels

    if total_pixels == 0:
        # The weight computation would divide by zero; return a zero that is
        # still connected to `outputs`.
        return outputs.sum() * 0.0

    # Read the class count from the input instead of a notebook-level variable
    n_classes = outputs.shape[1]
    class_weights = [0.0]  # class 0 is ignored
    for c in range(1, n_classes):
        class_pixels = float((targets == c).sum())
        class_weights.append(1 - (class_pixels / total_pixels))
    class_weights = torch.tensor(class_weights, dtype=outputs.dtype, device=outputs.device)

    # Convert softmax outputs to log probabilities
    log_probs = torch.log(outputs + 1e-6)  # Add epsilon for numerical stability

    # Unannotated pixels already have label 0 and weight 0, so the targets can
    # be passed unchanged.
    return torch.nn.functional.nll_loss(log_probs, targets, weight=class_weights, reduction='mean')


def pwce_already_sofmaxed2(outputs, targets):
    """
    Partial Weighted NLL loss for softmaxed outputs, computed only on annotated pixels.

    Class 0 marks unannotated pixels and gets weight 0. Class ``c >= 1`` gets
    ``1 - (pixels of class c) / (annotated pixels)``. The number of classes is
    taken from ``outputs.shape[1]``.

    :param outputs: Softmax probabilities from the neural network,
                    shape [batch_size, num_classes, height, width].
    :param targets: Ground truth labels, shape [batch_size, height, width].
                    Annotated pixels are non-zero.
    :return: Weighted NLL loss; zero when no pixel is annotated.
    """
    annotated = float((targets != 0).sum())  # Total number of labeled pixels
    if annotated == 0:
        # Dividing by the annotated-pixel count below would fail; return a
        # zero that is still connected to `outputs`.
        return outputs.sum() * 0.0

    # Class count comes from the input, not from a notebook-level variable
    foreground_weights = [
        1 - float((targets == c).sum()) / annotated
        for c in range(1, outputs.shape[1])
    ]
    class_weights = torch.tensor(
        [0.0] + foreground_weights,  # leading 0 switches class 0 off
        dtype=outputs.dtype,
        device=outputs.device,
    )

    log_probs = torch.log(outputs + 1e-6)  # Add epsilon for numerical stability

    return torch.nn.functional.nll_loss(log_probs, targets, weight=class_weights, reduction='mean')



# def weighted_cross_entropy_loss_fuull(outputs, targets, class_weights):
#     """
#     Weighted cross entropy loss for already softmaxed outputs.

#     :param outputs: Softmax probabilities from the neural network,
#                     shape [batch_size, num_classes, height, width].
#     :param targets: Ground truth labels, shape [batch_size, height, width].
#     :param class_weights: Tensor of shape [num_classes] with class weights.
#     :return: Weighted cross entropy loss.
#     """
#     # Ensure class_weights is a tensor and move to the same device as outputs
#     class_weights = torch.tensor(class_weights, dtype=outputs.dtype).to(outputs.device)

#     # Reshape targets to use for gather
#     targets = targets.unsqueeze(1)  # shape becomes [batch_size, 1, height, width]

#     # Gather the probabilities of the true classes
#     true_class_probs = torch.gather(outputs, 1, targets).squeeze(1)

#     # Apply weights to the gathered probabilities
#     weights = class_weights[targets].squeeze(1)  # shape becomes [batch_size, height, width]

#     # Compute the loss
#     loss = -torch.log(true_class_probs + 1e-6) * weights  # Adding epsilon for numerical stability
#     print(loss)
#     return loss.mean()
def w_nll_loss(prediction, label, n_classes = False):
    """Class-balanced negative log-likelihood for already-softmaxed predictions.

    Class ``i`` is weighted by ``1 - (label pixels equal to i) / (all label
    pixels)``; the weighted mean is the one computed by ``nll_loss``.

    :param prediction: Class probabilities with the class dimension at dim 1
                       (presumably [batch_size, num_classes, height, width]).
    :param label: Integer class index per pixel (presumably
                  [batch_size, height, width]).
    :param n_classes: Number of classes; any falsy value (the default) means
                      it is read from ``prediction.shape[1]``.
    :return: Scalar loss tensor.
    """
    if not n_classes:
        n_classes = prediction.shape[1]

    total_pixels = label.numel()
    class_frequencies = [torch.sum(label == i).item() / total_pixels for i in range(n_classes)]
    # Weights follow the dtype/device of the prediction so nll_loss accepts them
    weights = torch.tensor(
        [1 - frequency for frequency in class_frequencies],
        dtype=prediction.dtype,
        device=prediction.device,
    )

    # log(0) would make the loss infinite; clamp to the smallest normal float,
    # which leaves every non-zero probability unchanged.
    smallest = torch.finfo(prediction.dtype).tiny
    log_prediction = torch.log(prediction.clamp_min(smallest))

    # torch.nn.functional is reached through `torch`, so this cell does not
    # depend on an `F` alias imported elsewhere.
    return torch.nn.functional.nll_loss(log_prediction, label, weight=weights)


# Example usage: one random 100x100 sample with 4 classes
batch_size, num_classes, height, width = 1, 4, 100, 100
o = torch.randn(batch_size, num_classes, height, width)
outputs = o.softmax(dim=1)  # Example softmaxed outputs
targets = torch.randint(
    low=0, high=num_classes, size=(batch_size, height, width)
)  # Example targets with some zeros
# targets = torch.tensor([[[0, 1], [2, 2]]])

# Evaluate the full and the two partial weighted losses on the same input
for loss_fn in (wce_already_sofmaxed, pwce_already_sofmaxed, pwce_already_sofmaxed2):
    loss = loss_fn(outputs, targets)
    print(loss)




In [None]:
import torch
import torch.nn.functional as F

def dice_loss(predictions, labels, include_background = False, apply_softamax = False):
    """Compute the soft Dice loss as 1 minus the mean per-sample Dice score.

    For every sample the overlap and the sizes are summed over all selected
    classes and all pixels, which gives one Dice score per sample; the scores
    are then averaged over the batch.

    Args:
        predictions (tensor(batch, num_clases, height, width)): predictions from segmentation models 0 to 1 probabilites for class
        labels (tensor(batch, height, width)): labels for each pixel, values = 0...number of classes - 1
        include_background (bool, optional): if True include background class (channel 0). Defaults to False.
        apply_softamax (bool, optional): if True apply softmax over the class dimension to
            ``predictions`` first. Defaults to False.

    Returns:
        tensor: scalar loss
    """

    n_classes = predictions.shape[1]
    # one_hot appends the class dimension last; move it to dim 1
    labels_oh = torch.nn.functional.one_hot(labels, num_classes=n_classes).permute((0,3,1,2))
    if apply_softamax:
        predictions = torch.nn.functional.softmax(predictions, dim = 1)

    if not include_background:
        # Drop channel 0 from both tensors
        predictions = predictions[:,1:, :, :]
        labels_oh = labels_oh[:,1:,:,:]

    intersection = torch.sum(predictions * labels_oh, dim=(1,2,3))
    union = torch.sum(predictions, dim=(1,2,3)) + torch.sum(labels_oh, dim=(1,2,3))

    # A sample with nothing predicted and nothing labelled has union == 0 and
    # would yield 0/0 = NaN. Treat it as perfect agreement (Dice = 1); the
    # substitute denominator keeps the unused branch free of NaN as well.
    empty = union <= 0
    safe_union = torch.where(empty, torch.ones_like(union), union)
    dice_score = torch.where(empty, torch.ones_like(union), (2.0 * intersection) / safe_union)
    return 1 - torch.mean(dice_score, dim=0)

# Define the DiceLoss class as previously described
class DiceLoss(torch.nn.Module):
    """Soft Dice loss on sigmoid-activated predictions.

    A Dice score is computed for every (sample, channel) pair over the
    flattened remaining dimensions; the loss is 1 minus the mean score.
    ``smooth`` is added to numerator and denominator.
    """

    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        # Squash raw scores into (0, 1)
        probabilities = torch.sigmoid(predictions)

        # Collapse everything after (batch, channel) into one axis
        flat_probabilities = probabilities.reshape(probabilities.shape[0], probabilities.shape[1], -1)
        flat_targets = targets.reshape(targets.shape[0], targets.shape[1], -1)

        overlap = torch.sum(flat_probabilities * flat_targets, dim=2)
        total = torch.sum(flat_probabilities, dim=2) + torch.sum(flat_targets, dim=2)

        per_channel_dice = (2. * overlap + self.smooth) / (total + self.smooth)
        return 1 - per_channel_dice.mean()

# Simulate predictions and targets
batch_size, num_classes, height, width = 4, 3, 256, 256  # Example dimensions

# Randomly generated predictions for testing
predictions = torch.rand(batch_size, num_classes, height, width)

# Integer class map (what `dice_loss` expects) and its one-hot float version
# (what `DiceLoss` expects). Passing the float tensor to `dice_loss` raised an
# error inside `one_hot`, which only accepts integer index tensors.
target_classes = torch.randint(0, num_classes, (batch_size, height, width))
targets = F.one_hot(target_classes, num_classes).permute(0, 3, 1, 2).float()

# Initialize Dice Loss
dice_loss2 = DiceLoss()

# Class-based loss on the one-hot targets
loss = dice_loss2(predictions, targets)

print(f'Dice Loss: {loss.item()}')

# Function-based loss on the class-index targets, background included
loss = dice_loss(predictions, target_classes, True)

print(f'Dice Loss: {loss.item()}')

In [None]:
import torch
import torch.nn.functional as F

def dice_loss(predictions, labels, include_background = False, apply_softamax = False):
    """Compute the soft Dice loss as 1 minus the mean per-sample Dice score.

    Overlap and sizes are accumulated per sample over all selected classes and
    pixels (one Dice score per sample); the scores are averaged over the batch.

    Args:
        predictions (tensor(batch, num_clases, height, width)): predictions from segmentation models 0 to 1 probabilites for class
        labels (tensor(batch, height, width)): labels for each pixel, values = 0...number of classes - 1
        include_background (bool, optional): if True include background class (channel 0). Defaults to False.
        apply_softamax (bool, optional): if True apply softmax over the class dimension to
            ``predictions`` first. Defaults to False.

    Returns:
        tensor: scalar loss
    """

    n_classes = predictions.shape[1]
    # one_hot puts the class axis last; permute it to dim 1 to match predictions
    labels_oh = torch.nn.functional.one_hot(labels, num_classes=n_classes).permute((0,3,1,2))
    if apply_softamax:
        predictions = torch.nn.functional.softmax(predictions, dim = 1)

    if not include_background:
        # Skip channel 0 in predictions and labels alike
        predictions = predictions[:,1:, :, :]
        labels_oh = labels_oh[:,1:,:,:]

    reduce_dims = (1, 2, 3)
    intersection = torch.sum(predictions * labels_oh, dim=reduce_dims)
    union = torch.sum(predictions, dim=reduce_dims) + torch.sum(labels_oh, dim=reduce_dims)

    # union == 0 (nothing predicted, nothing labelled) would give 0/0 = NaN;
    # count such a sample as perfect agreement (Dice = 1). The replacement
    # denominator keeps the discarded branch NaN-free too.
    is_empty = union <= 0
    ones = torch.ones_like(union)
    dice_score = torch.where(is_empty, ones, (2.0 * intersection) / torch.where(is_empty, ones, union))
    return 1 - torch.mean(dice_score, dim=0)

class DiceLoss(torch.nn.Module):
    """Soft Dice loss on softmax-activated predictions.

    Softmax is applied over dim 1, a Dice score is computed for every
    (sample, class) pair over the flattened remaining dimensions, and the loss
    is 1 minus the mean score. ``smooth`` (0 by default) is added to numerator
    and denominator.
    """

    def __init__(self, smooth=0):
        super().__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        # Turn raw scores into class probabilities
        probabilities = torch.softmax(predictions, dim=1)

        # Collapse everything after (batch, class) into one axis
        flat_probabilities = probabilities.reshape(probabilities.shape[0], probabilities.shape[1], -1)
        flat_targets = targets.reshape(targets.shape[0], targets.shape[1], -1)

        overlap = torch.sum(flat_probabilities * flat_targets, dim=2)
        total = torch.sum(flat_probabilities, dim=2) + torch.sum(flat_targets, dim=2)

        per_class_dice = (2. * overlap + self.smooth) / (total + self.smooth)
        return 1 - per_class_dice.mean()

# Simulated batch used to compare the class-based and function-based Dice loss
batch_size, num_classes, height, width = 4, 3, 256, 256  # Example dimensions

# Random raw scores for every class and pixel
predictions = torch.rand(batch_size, num_classes, height, width)

# Random class index for every pixel
target_classes = torch.randint(
    low=0, high=num_classes, size=(batch_size, height, width)
)

# One-hot version with the class axis moved to dim 1, as float
targets_one_hot = torch.movedim(F.one_hot(target_classes, num_classes), -1, 1).float()

# Class-based loss (applies softmax internally) on the one-hot targets
dice_loss2 = DiceLoss()
loss = dice_loss2(predictions, targets_one_hot)
# print(predictions)
# print(target_classes)
print(f'Dice Loss: {loss.item()}')

# Function-based loss on the class indices, background included, softmax on
loss = dice_loss(predictions, target_classes, include_background=True, apply_softamax=True)

print(f'Dice Loss: {loss.item()}')