In [None]:
import os
import random
import numpy as np

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

from xception import Xception

# Dice similarity coefficient

In [None]:
def dice_coefficient(prediction, target, epsilon=1e-7):
    """
    Return the Dice similarity coefficient of two tensors as a Python float.

    The score is ``(2 * sum(prediction * target) + epsilon) /
    (sum(prediction) + sum(target) + epsilon)``, reduced over every element
    of the inputs (no per-sample or per-class averaging).

    Args:
        prediction (torch.Tensor): Model predictions (binary values, 0 or 1).
        target (torch.Tensor): Ground-truth values (binary values, 0 or 1).
        epsilon (float): Small constant added to numerator and denominator so
            the ratio stays defined when both inputs sum to zero.

    Returns:
        float: Dice similarity coefficient.
    """
    # Numerator: twice the element-wise overlap, smoothed by epsilon.
    numerator = 2. * torch.sum(prediction * target) + epsilon
    # Denominator: sum of both element totals (not a set union), smoothed.
    denominator = torch.sum(prediction) + torch.sum(target) + epsilon
    # .item() converts the 0-dim result tensor into a plain Python number.
    return (numerator / denominator).item()

# Build toy binary masks for the example
prediction = torch.tensor(
    [
        [0, 1, 1],
        [1, 1, 0],
        [0, 1, 0],
    ],
    dtype=torch.float32,
)
target = torch.tensor(
    [
        [1, 1, 0],
        [1, 0, 0],
        [0, 1, 1],
    ],
    dtype=torch.float32,
)

# Compute and report the Dice similarity coefficient
dice_score = dice_coefficient(prediction=prediction, target=target)
print("Dice similarity coefficient:", dice_score)


# Dice loss

In [None]:
def dice_loss(prediction, target, epsilon=1e-7):
    """
    Return the Dice loss, i.e. ``1 - Dice score``, as a tensor.

    The Dice score is ``(2 * sum(prediction * target) + epsilon) /
    (sum(prediction) + sum(target) + epsilon)``, reduced over every element
    of the inputs. The result is built only from tensor operations and is
    returned as a tensor rather than converted to a Python number.

    Args:
        prediction (torch.Tensor): Model predictions (probabilities).
        target (torch.Tensor): Ground-truth values (binary values, 0 or 1).
        epsilon (float): Small constant added to numerator and denominator so
            the ratio stays defined when both inputs sum to zero.

    Returns:
        torch.Tensor: Dice loss.
    """
    # Numerator: twice the element-wise overlap, smoothed by epsilon.
    numerator = 2. * torch.sum(prediction * target) + epsilon
    # Denominator: sum of both element totals (not a set union), smoothed.
    denominator = torch.sum(prediction) + torch.sum(target) + epsilon
    score = numerator / denominator
    return 1 - score

# Build toy data for the example: predicted probabilities and a binary mask
prediction = torch.tensor(
    [
        [0.8, 0.7, 0.6],
        [0.3, 0.4, 0.5],
        [0.2, 0.9, 0.1],
    ],
    dtype=torch.float32,
)
target = torch.tensor(
    [
        [1, 1, 0],
        [1, 0, 0],
        [0, 1, 1],
    ],
    dtype=torch.float32,
)

# Compute and report the Dice loss
loss = dice_loss(prediction=prediction, target=target)
print("Dice Loss:", loss.item())
