## Import packages

In [21]:
import torch
import numpy as np
import time
from scipy.ndimage import distance_transform_edt as distance
from scipy.ndimage import _nd_image

In [22]:
# Handle for the CUDA device with index 0; printed so the selection is
# recorded in the cell output.
device = torch.device("cuda", 0)
print(device)

cuda:0


# Loss Functions

In [23]:
def dice_loss(score, target):
    """Return the Dice loss ``1 - dice`` between ``score`` and ``target``.

    The coefficient is computed over all elements at once:
    ``dice = (2 * sum(score * target) + eps) / (sum(score**2) + sum(target**2) + eps)``
    with ``eps = 1e-5`` keeping the ratio defined when both sums are zero.

    ``target`` is cast to float before use; ``score`` is used as given.
    The result is a 0-dim tensor.
    """
    eps = 1e-5
    target_f = target.float()
    overlap = (score * target_f).sum()
    score_sq = (score * score).sum()
    target_sq = (target_f * target_f).sum()
    dice = (2 * overlap + eps) / (score_sq + target_sq + eps)
    return 1 - dice



In [24]:
def hd_loss_2D(seg_soft, gt, seg_dtm, gt_dtm):
    """
    compute Hausdorff distance loss for a 2D binary segmentation
    input: seg_soft: predicted segmentation, 2D tensor
           gt: ground truth, 2D tensor
           seg_dtm: segmentation distance transform map, 2D tensor
           gt_dtm: ground truth distance transform map, 2D tensor
    output: scalar tensor, mean((seg_soft - gt)^2 * (seg_dtm^2 + gt_dtm^2))
    """

    squared_error = (seg_soft - gt) ** 2
    distance_weight = seg_dtm ** 2 + gt_dtm ** 2
    # The einsum spec names exactly two axes per operand, so both operands
    # must be 2D; it multiplies them element by element.
    weighted = torch.einsum('xy, xy->xy', squared_error, distance_weight)

    return weighted.mean()

In [25]:
def hd_loss_3D(seg_soft, gt, seg_dtm, gt_dtm):
    """
    compute Hausdorff distance loss for a 3D binary segmentation
    input: seg_soft: predicted segmentation, 3D tensor
           gt: ground truth, 3D tensor
           seg_dtm: segmentation distance transform map, 3D tensor
           gt_dtm: ground truth distance transform map, 3D tensor
    output: scalar tensor, mean((seg_soft - gt)^2 * (seg_dtm^2 + gt_dtm^2))
    """

    squared_error = (seg_soft - gt) ** 2
    distance_weight = seg_dtm ** 2 + gt_dtm ** 2
    # The einsum spec names exactly three axes per operand, so both operands
    # must be 3D; it multiplies them element by element.
    weighted = torch.einsum('xyz, xyz->xyz', squared_error, distance_weight)

    return weighted.mean()

# Benchmarks

In [26]:
# Side lengths n used to build the benchmark inputs: 1, 11, 21, ..., 501 (51 values).
num_range = range(1, 510, 10)

In [27]:
# Element counts of the benchmarked inputs.
sizes_2D, sizes_3D = [], []

# Per-size summary statistics of the timings (minimum and standard deviation).
hd_min_2D, hd_std_2D = [], []
dice_min_2D, dice_std_2D = [], []
hd_min_3D, hd_std_3D = [], []
dice_min_3D, dice_std_3D = [], []

# Raw per-evaluation timings: one inner list per input size.
hd_times, dice_times = [], []
hd_3D_times, dice_3D_times = [], []

In [28]:
# Build one random binary (n, n) tensor per side length, together with its
# element count.  Both lists are rebuilt from scratch (rather than appended
# to) so that re-running this cell cannot leave sizes_2D with more entries
# than test_cases_2D.
sizes_2D = [n**2 for n in num_range]
test_cases_2D = [torch.randint(0, 2, (n, n)) for n in num_range]

In [29]:
# Build one random binary (n, n, n) tensor per side length, together with its
# element count.  Both lists are rebuilt from scratch (rather than appended
# to) so that re-running this cell cannot leave sizes_3D with more entries
# than test_cases_3D.
sizes_3D = [n**3 for n in num_range]
test_cases_3D = [torch.randint(0, 2, (n, n, n)) for n in num_range]

In [30]:
#2D
# Start from empty result lists so that re-running this cell replaces the
# measurements instead of appending a second set after the first.
hd_times = []
dice_times = []

for array in test_cases_2D:
  # HD CPU
  print("HD Size: ",  len(array))
  # Distance maps and the random ground truth are prepared once per size,
  # outside the timed region.
  tfm1 = torch.from_numpy(distance(array))
  n = len(array)
  g = torch.randint(0,2,(n,n))
  tfm2 = torch.from_numpy(distance(g))
  a = []
  elapsed = 0  # running total in ns, so the list is not re-summed every evaluation
  for j in range(1000): #Evaluations
    times1 = time.perf_counter_ns()
    hd = hd_loss_2D(array, g, tfm1, tfm2)
    times2 = time.perf_counter_ns()
    a.append(times2-times1)
    elapsed += times2-times1
    # Stop early once this size has used 15 minutes of measured time.
    # NOTE(review): the other three benchmark loops use a 20 minute budget;
    # confirm whether 15 is intentional here.
    if elapsed > (15*60*(10**9)):
      break
  hd_times.append(a)

for array in test_cases_2D:
  # Dice CPU
  print("Dice Size: ",  len(array))
  b = []
  n = len(array)
  g = torch.randint(0,2,(n,n))
  elapsed = 0  # running total in ns, so the list is not re-summed every evaluation
  for j in range(1000): #Evaluations
    times1 = time.perf_counter_ns()
    dice = dice_loss(array, g)
    times2 = time.perf_counter_ns()
    b.append(times2-times1)
    elapsed += times2-times1
    # Stop early once this size has used 20 minutes of measured time.
    if elapsed > (20*60*(10**9)):
      break
  dice_times.append(b)
  



HD Size:  1
HD Size:  11
HD Size:  21
HD Size:  31
HD Size:  41
HD Size:  51
HD Size:  61
HD Size:  71
HD Size:  81
HD Size:  91
HD Size:  101
HD Size:  111
HD Size:  121
HD Size:  131
HD Size:  141
HD Size:  151
HD Size:  161
HD Size:  171
HD Size:  181
HD Size:  191
HD Size:  201
HD Size:  211
HD Size:  221
HD Size:  231
HD Size:  241
HD Size:  251
HD Size:  261
HD Size:  271
HD Size:  281
HD Size:  291
HD Size:  301
HD Size:  311
HD Size:  321
HD Size:  331
HD Size:  341
HD Size:  351
HD Size:  361
HD Size:  371
HD Size:  381
HD Size:  391
HD Size:  401
HD Size:  411
HD Size:  421
HD Size:  431
HD Size:  441
HD Size:  451
HD Size:  461
HD Size:  471
HD Size:  481
HD Size:  491
HD Size:  501
Dice Size:  1
Dice Size:  11
Dice Size:  21
Dice Size:  31
Dice Size:  41
Dice Size:  51
Dice Size:  61
Dice Size:  71
Dice Size:  81
Dice Size:  91
Dice Size:  101
Dice Size:  111
Dice Size:  121
Dice Size:  131
Dice Size:  141
Dice Size:  151
Dice Size:  161
Dice Size:  171
Dice Size:  181
Dice

In [31]:
# 3D
# Start from empty result lists so that re-running this cell replaces the
# measurements instead of appending a second set after the first.
hd_3D_times = []
dice_3D_times = []

for array in test_cases_3D:
  # HD CPU
  print("HD Size: ", len(array))
  a = []
  # Distance maps and the random ground truth are prepared once per size,
  # outside the timed region.
  tfm1 = torch.from_numpy(distance(array))
  n = len(array)
  g = torch.randint(0,2,(n,n,n))
  tfm2 = torch.from_numpy(distance(g))
  elapsed = 0  # running total in ns, so the list is not re-summed every evaluation
  for j in range(1000): #Evaluations
    times1 = time.perf_counter_ns()
    hd = hd_loss_3D(array, g, tfm1, tfm2)
    times2 = time.perf_counter_ns()
    a.append(times2-times1)
    elapsed += times2-times1
    # Stop early once this size has used 20 minutes of measured time.
    if elapsed > (20*60*(10**9)):
      break
  hd_3D_times.append(a)

for array in test_cases_3D:
  # Dice CPU
  print("Dice Size: ",  len(array))
  b = []
  n = len(array)
  g = torch.randint(0,2,(n,n,n))
  elapsed = 0  # running total in ns, so the list is not re-summed every evaluation
  for j in range(1000): #Evaluations
    times1 = time.perf_counter_ns()
    dice = dice_loss(array, g)
    times2 = time.perf_counter_ns()
    b.append(times2-times1)
    elapsed += times2-times1
    # Stop early once this size has used 20 minutes of measured time.
    if elapsed > (20*60*(10**9)):
      break
  dice_3D_times.append(b)
  

HD Size:  1
HD Size:  11
HD Size:  21
HD Size:  31
HD Size:  41
HD Size:  51
HD Size:  61
HD Size:  71
HD Size:  81
HD Size:  91
HD Size:  101
HD Size:  111
HD Size:  121
HD Size:  131
HD Size:  141
HD Size:  151
HD Size:  161
HD Size:  171
HD Size:  181
HD Size:  191
HD Size:  201
HD Size:  211
HD Size:  221
HD Size:  231
HD Size:  241
HD Size:  251
HD Size:  261
HD Size:  271
HD Size:  281
HD Size:  291
HD Size:  301
HD Size:  311
HD Size:  321
HD Size:  331
HD Size:  341
HD Size:  351
HD Size:  361
HD Size:  371
HD Size:  381
HD Size:  391
HD Size:  401
HD Size:  411
HD Size:  421
HD Size:  431
HD Size:  441
HD Size:  451
HD Size:  461
HD Size:  471
HD Size:  481
HD Size:  491
HD Size:  501
Dice Size:  1
Dice Size:  11
Dice Size:  21
Dice Size:  31
Dice Size:  41
Dice Size:  51
Dice Size:  61
Dice Size:  71
Dice Size:  81
Dice Size:  91
Dice Size:  101
Dice Size:  111
Dice Size:  121
Dice Size:  131
Dice Size:  141
Dice Size:  151
Dice Size:  161
Dice Size:  171
Dice Size:  181
Dice

In [32]:
# Summarise the raw 2D HD timings (ns) of each input size by their minimum and
# population standard deviation.
# float64 is used because nanosecond counts above 2**24 (~16.8 ms) are not
# exactly representable in float32.  The lists are rebuilt rather than
# appended to, so re-running this cell does not duplicate entries.
hd_min_2D = []
hd_std_2D = []
for timings_ns in hd_times:
  samples = torch.tensor(timings_ns, dtype=torch.float64)
  hd_min_2D.append(samples.min().item())
  hd_std_2D.append(samples.std(unbiased=False).item())

In [33]:
# Summarise the raw 2D Dice timings (ns) of each input size by their minimum
# and population standard deviation.
# float64 is used because nanosecond counts above 2**24 (~16.8 ms) are not
# exactly representable in float32.  The lists are rebuilt rather than
# appended to, so re-running this cell does not duplicate entries.
dice_min_2D = []
dice_std_2D = []
for timings_ns in dice_times:
  samples = torch.tensor(timings_ns, dtype=torch.float64)
  dice_min_2D.append(samples.min().item())
  dice_std_2D.append(samples.std(unbiased=False).item())

In [34]:
# Summarise the raw 3D HD timings (ns) of each input size by their minimum and
# population standard deviation.
# float64 is used because nanosecond counts above 2**24 (~16.8 ms) are not
# exactly representable in float32.  The lists are rebuilt rather than
# appended to, so re-running this cell does not duplicate entries.
hd_min_3D = []
hd_std_3D = []
for timings_ns in hd_3D_times:
  samples = torch.tensor(timings_ns, dtype=torch.float64)
  hd_min_3D.append(samples.min().item())
  hd_std_3D.append(samples.std(unbiased=False).item())

In [35]:
# Summarise the raw 3D Dice timings (ns) of each input size by their minimum
# and population standard deviation.
# float64 is used because nanosecond counts above 2**24 (~16.8 ms) are not
# exactly representable in float32.  The lists are rebuilt rather than
# appended to, so re-running this cell does not duplicate entries.
dice_min_3D = []
dice_std_3D = []
for timings_ns in dice_3D_times:
  samples = torch.tensor(timings_ns, dtype=torch.float64)
  dice_min_3D.append(samples.min().item())
  dice_std_3D.append(samples.std(unbiased=False).item())

### Save

In [36]:
import pandas as pd

In [37]:
# Convert the per-size summary lists into NumPy arrays for the result tables.
# Minimum timings:
hd_min, dice_min, hd_3D_min, dice_3D_min = map(
    np.array, (hd_min_2D, dice_min_2D, hd_min_3D, dice_min_3D)
)
# Standard deviations:
hd_std, dice_std, hd_3D_std, dice_3D_std = map(
    np.array, (hd_std_2D, dice_std_2D, hd_std_3D, dice_std_3D)
)

In [38]:
# Column name -> values for the 2D results table.
data2D = {
    'sizes_2D': sizes_2D,
    'hd_min_2D_purePython': hd_min,
    'dice_min_2D_purePython': dice_min,
    'hd_std_2D_purePython': hd_std,
    'dice_std_2D_purePython': dice_std,
}
# Column name -> values for the 3D results table.
data3D = {
    'sizes_3D': sizes_3D,
    'hd_min_3D_purePython': hd_3D_min,
    'dice_min_3D_purePython': dice_3D_min,
    'hd_std_3D_purePython': hd_3D_std,
    'dice_std_3D_purePython': dice_3D_std,
}

In [39]:
# Assemble one results table per dimensionality from the column dicts.
dataframe2D = pd.DataFrame(data2D)
dataframe3D = pd.DataFrame(data3D)

In [40]:
# Directory that receives the benchmark CSVs.  It is an absolute,
# machine-specific path; change it here (one place) when running elsewhere.
OUTPUT_DIR = "C:/Users/wenbl13/Desktop/Ashwin-Timing/distance-transforms"

dataframe2D.to_csv(f"{OUTPUT_DIR}/purePython_Loss_2D_nov6.csv")
dataframe3D.to_csv(f"{OUTPUT_DIR}/purePython_Loss_3D_nov6.csv")

In [41]:
# Show the 2D results table as this cell's output.
dataframe2D

Unnamed: 0,sizes_2D,hd_min_2D_purePython,dice_min_2D_purePython,hd_std_2D_purePython,dice_std_2D_purePython
0,1,29400.0,48000.0,49707.75,319482.9
1,121,31400.0,49000.0,10313.63,8556.633
2,441,33300.0,55400.0,10261.51,9789.611
3,961,36900.0,53100.0,11097.37,27184.7
4,1681,41200.0,51800.0,34685.68,43157.58
5,2601,49900.0,54900.0,17486.16,9581.366
6,3721,56300.0,58700.0,15318.79,3549.291
7,5041,65000.0,62100.0,26446.97,34375.0
8,6561,73400.0,66500.0,17146.78,8827.031
9,8281,84300.0,71800.0,18670.01,4629.172


In [42]:
# Ten evenly spaced values: 1, 101, ..., 901.
x = list(range(1, 1000, 100))

In [43]:
import matplotlib.pyplot as plt


In [44]:
# plt.figure(figsize=(13, 13))
# plt.plot(x, dataframe['hd_min_2D_purePython'], label = 'hd_min_2D')
# plt.plot(x, dataframe['dice_min_2D_purePython'], label = 'dice_min_2D')
# plt.plot(x, dataframe['hd_min_3D_purePython'], label = 'hd_min_3D')
# plt.plot(x, dataframe['dice_min_3D_purePython'], label = 'dice_min_3D')
# plt.xlabel('Array_Size')
# plt.ylabel('Time (ns)')
# plt.legend()
# plt.show()