## Import packages

In [11]:
import os
import time
from pathlib import Path

import numpy as np
import torch
from scipy.ndimage import _nd_image
from scipy.ndimage import distance_transform_edt as distance

In [12]:
# All benchmark tensors are placed on the first CUDA GPU (index 0).
device = torch.device("cuda", 0)
print(device)

cuda:0


# Loss Functions

In [13]:
def dice_loss(score, target):
    """
    Soft Dice loss between a prediction and a target mask.

    loss = 1 - (2 * sum(score * target) + eps) / (sum(score**2) + sum(target**2) + eps)

    score: prediction tensor (probabilities or a 0/1 mask)
    target: tensor combined elementwise with score; cast to float here
    Returns a 0-d tensor. eps = 1e-5 keeps the ratio defined when both inputs are all zeros.
    """
    eps = 1e-5
    tgt = target.float()
    overlap = torch.sum(score * tgt)
    pred_sq = torch.sum(score * score)
    tgt_sq = torch.sum(tgt * tgt)
    dice = (2 * overlap + eps) / (pred_sq + tgt_sq + eps)
    return 1 - dice



In [14]:
def hd_loss_2D(seg_soft, gt, seg_dtm, gt_dtm):
    """
    Compute the distance-transform based Hausdorff distance loss for binary
    segmentation (2D variant):

        loss = mean((seg_soft - gt)**2 * (seg_dtm**2 + gt_dtm**2))

    input: seg_soft: predicted foreground probabilities (or a 0/1 mask)
           gt: ground-truth mask
           seg_dtm: distance transform map of the segmentation
           gt_dtm: distance transform map of the ground truth
           All four are combined elementwise, so they must share a shape or be
           broadcast-compatible. In this notebook they are (n, n) tensors.
    output: hd_loss; 0-d tensor (scalar)
    """
    delta_s = (seg_soft - gt) ** 2
    dtm = seg_dtm ** 2 + gt_dtm ** 2
    # No index is contracted, so this is a plain elementwise product. Using `*`
    # instead of einsum('xy,xy->xy') also accepts inputs of any rank.
    weighted = delta_s * dtm
    return weighted.mean()

In [15]:
def hd_loss_3D(seg_soft, gt, seg_dtm, gt_dtm):
    """
    Compute the distance-transform based Hausdorff distance loss for binary
    segmentation (3D variant):

        loss = mean((seg_soft - gt)**2 * (seg_dtm**2 + gt_dtm**2))

    input: seg_soft: predicted foreground probabilities (or a 0/1 mask)
           gt: ground-truth mask
           seg_dtm: distance transform map of the segmentation
           gt_dtm: distance transform map of the ground truth
           All four are combined elementwise, so they must share a shape or be
           broadcast-compatible. In this notebook they are (n, n, n) tensors.
    output: hd_loss; 0-d tensor (scalar)
    """
    delta_s = (seg_soft - gt) ** 2
    dtm = seg_dtm ** 2 + gt_dtm ** 2
    # No index is contracted, so this is a plain elementwise product. Using `*`
    # instead of einsum('xyz,xyz->xyz') also accepts inputs of any rank.
    weighted = delta_s * dtm
    return weighted.mean()

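# Benchmarks

Each loss is timed on random binary tensors: squares (2D) and cubes (3D) with side lengths 1, 11, …, 501.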

In [16]:
num_range = range(1, 502, 10)  # side lengths 1, 11, ..., 501 (51 sizes)

In [17]:
# Element count of every test case.
sizes_2D, sizes_3D = [], []

# Per-test-case mean / standard deviation of the timings.
hd_mean_2D, hd_std_2D = [], []
dice_mean_2D, dice_std_2D = [], []
hd_mean_3D, hd_std_3D = [], []
dice_mean_3D, dice_std_3D = [], []

# Raw per-evaluation timings in nanoseconds, one inner list per test case.
hd_times, dice_times = [], []
hd_3D_times, dice_3D_times = [], []

In [18]:
# One random binary (0/1) square tensor per side length in num_range.
# sizes_2D records each tensor's element count (side * side).
test_cases_2D = []
for side in num_range:
    sizes_2D.append(side * side)
    test_cases_2D.append(torch.randint(0, 2, (side, side)))

In [19]:
# One random binary (0/1) cubic tensor per side length in num_range.
# sizes_3D records each tensor's element count (side ** 3).
# NOTE(review): randint returns int64 here. Across all 51 sizes this holds about
# 1.6e9 elements (~12.6 GB) in host memory at once.
test_cases_3D = []
for side in num_range:
    sizes_3D.append(side * side * side)
    test_cases_3D.append(torch.randint(0, 2, (side, side, side)))

In [21]:
#2D
# Time hd_loss_2D and dice_loss on `device` for every 2D test case.
# CUDA kernels run asynchronously, so the device is synchronized before and
# after each timed call. Without this, only the kernel-launch overhead is measured.
# time.perf_counter_ns is monotonic and high-resolution. time.time_ns is a wall
# clock whose resolution can be much coarser on some platforms.
MAX_EVALS = 1000                       # evaluations per test case
HD_2D_BUDGET_NS = 15 * 60 * 10**9      # stop a size after 15 min of measured HD time
DICE_2D_BUDGET_NS = 20 * 60 * 10**9    # stop a size after 20 min of measured Dice time
on_cuda = device.type == "cuda"

for array in test_cases_2D:
  # HD GPU
  print("HD Size: ",  len(array))
  n = len(array)
  g = torch.randint(0, 2, (n, n))
  # Distance transforms are computed on the CPU with scipy, then moved to the device.
  tfm1 = torch.from_numpy(distance(array)).to(device)
  tfm2 = torch.from_numpy(distance(g)).to(device)
  g1 = g.to(device)
  array1 = array.to(device)
  # Untimed warm-up call, so one-off first-call costs do not land in the first sample.
  hd = hd_loss_2D(array1, g1, tfm1, tfm2)
  a = []
  total_ns = 0
  for _ in range(MAX_EVALS):
    if on_cuda:
      torch.cuda.synchronize(device)
    start = time.perf_counter_ns()
    hd = hd_loss_2D(array1, g1, tfm1, tfm2)
    if on_cuda:
      torch.cuda.synchronize(device)
    elapsed = time.perf_counter_ns() - start
    a.append(elapsed)
    total_ns += elapsed  # running total; re-summing `a` every iteration is O(n^2)
    if total_ns > HD_2D_BUDGET_NS:
      break
  hd_times.append(a)

for array in test_cases_2D:
  # Dice GPU
  print("Dice Size: ",  len(array))
  n = len(array)
  g = torch.randint(0, 2, (n, n)).to(device)
  array1 = array.to(device)
  dice = dice_loss(array1, g)  # untimed warm-up call
  b = []
  total_ns = 0
  for _ in range(MAX_EVALS):
    if on_cuda:
      torch.cuda.synchronize(device)
    start = time.perf_counter_ns()
    dice = dice_loss(array1, g)
    if on_cuda:
      torch.cuda.synchronize(device)
    elapsed = time.perf_counter_ns() - start
    b.append(elapsed)
    total_ns += elapsed
    if total_ns > DICE_2D_BUDGET_NS:
      break
  dice_times.append(b)
  



HD Size:  1


AssertionError: Torch not compiled with CUDA enabled

In [None]:
# 3D
# Time hd_loss_3D and dice_loss on `device` for every 3D test case.
# CUDA kernels run asynchronously, so the device is synchronized before and
# after each timed call. Without this, only the kernel-launch overhead is measured.
# time.perf_counter_ns is monotonic and high-resolution. time.time_ns is a wall
# clock whose resolution can be much coarser on some platforms.
MAX_EVALS = 1000                       # evaluations per test case
HD_3D_BUDGET_NS = 20 * 60 * 10**9      # stop a size after 20 min of measured HD time
DICE_3D_BUDGET_NS = 20 * 60 * 10**9    # stop a size after 20 min of measured Dice time
on_cuda = device.type == "cuda"

for array in test_cases_3D:
  # HD GPU
  print("HD Size: ", len(array))
  n = len(array)
  g = torch.randint(0, 2, (n, n, n))
  # Distance transforms are computed on the CPU with scipy, then moved to the device.
  tfm1 = torch.from_numpy(distance(array)).to(device)
  tfm2 = torch.from_numpy(distance(g)).to(device)
  array1 = array.to(device)
  g1 = g.to(device)
  # Untimed warm-up call, so one-off first-call costs do not land in the first sample.
  hd = hd_loss_3D(array1, g1, tfm1, tfm2)
  a = []
  total_ns = 0
  for _ in range(MAX_EVALS):
    if on_cuda:
      torch.cuda.synchronize(device)
    start = time.perf_counter_ns()
    hd = hd_loss_3D(array1, g1, tfm1, tfm2)
    if on_cuda:
      torch.cuda.synchronize(device)
    elapsed = time.perf_counter_ns() - start
    a.append(elapsed)
    total_ns += elapsed  # running total; re-summing `a` every iteration is O(n^2)
    if total_ns > HD_3D_BUDGET_NS:
      break
  hd_3D_times.append(a)

for array in test_cases_3D:
  # Dice GPU
  print("Dice Size: ",  len(array))
  n = len(array)
  g = torch.randint(0, 2, (n, n, n)).to(device)
  array1 = array.to(device)
  dice = dice_loss(array1, g)  # untimed warm-up call
  b = []
  total_ns = 0
  for _ in range(MAX_EVALS):
    if on_cuda:
      torch.cuda.synchronize(device)
    start = time.perf_counter_ns()
    dice = dice_loss(array1, g)
    if on_cuda:
      torch.cuda.synchronize(device)
    elapsed = time.perf_counter_ns() - start
    b.append(elapsed)
    total_ns += elapsed
    if total_ns > DICE_3D_BUDGET_NS:
      break
  dice_3D_times.append(b)
  

HD Size:  1
HD Size:  11
HD Size:  21
HD Size:  31
HD Size:  41
HD Size:  51
HD Size:  61
HD Size:  71
HD Size:  81
HD Size:  91
HD Size:  101
HD Size:  111
HD Size:  121
HD Size:  131
HD Size:  141
HD Size:  151
HD Size:  161
HD Size:  171
HD Size:  181
HD Size:  191
HD Size:  201
HD Size:  211
HD Size:  221
HD Size:  231
HD Size:  241
HD Size:  251
HD Size:  261
HD Size:  271
HD Size:  281
HD Size:  291
HD Size:  301
HD Size:  311
HD Size:  321
HD Size:  331
HD Size:  341
HD Size:  351
HD Size:  361
HD Size:  371
HD Size:  381
HD Size:  391
HD Size:  401
HD Size:  411
HD Size:  421
HD Size:  431
HD Size:  441
HD Size:  451
HD Size:  461
HD Size:  471
HD Size:  481
HD Size:  491
HD Size:  501
Dice Size:  1
Dice Size:  11
Dice Size:  21
Dice Size:  31
Dice Size:  41
Dice Size:  51
Dice Size:  61
Dice Size:  71
Dice Size:  81
Dice Size:  91
Dice Size:  101
Dice Size:  111
Dice Size:  121
Dice Size:  131
Dice Size:  141
Dice Size:  151
Dice Size:  161
Dice Size:  171
Dice Size:  181
Dice

In [None]:
# Mean and population standard deviation (ddof=0) of each 2D HD test case's
# timings. This is computed in float64, so nanosecond counts are not rounded
# to float32 precision.
for timings in hd_times:
  samples = np.asarray(timings, dtype=np.float64)
  hd_mean_2D.append(float(samples.mean()))
  hd_std_2D.append(float(samples.std()))

In [None]:
# Mean and population standard deviation (ddof=0) of each 2D Dice test case's
# timings. This is computed in float64, so nanosecond counts are not rounded
# to float32 precision.
for timings in dice_times:
  samples = np.asarray(timings, dtype=np.float64)
  dice_mean_2D.append(float(samples.mean()))
  dice_std_2D.append(float(samples.std()))

In [None]:
# Mean and population standard deviation (ddof=0) of each 3D HD test case's
# timings. This is computed in float64, so nanosecond counts are not rounded
# to float32 precision.
for timings in hd_3D_times:
  samples = np.asarray(timings, dtype=np.float64)
  hd_mean_3D.append(float(samples.mean()))
  hd_std_3D.append(float(samples.std()))

In [None]:
# Mean and population standard deviation (ddof=0) of each 3D Dice test case's
# timings. This is computed in float64, so nanosecond counts are not rounded
# to float32 precision.
for timings in dice_3D_times:
  samples = np.asarray(timings, dtype=np.float64)
  dice_mean_3D.append(float(samples.mean()))
  dice_std_3D.append(float(samples.std()))

### Save

In [None]:
import pandas as pd

In [None]:
# Convert the per-size summary lists to numpy arrays for the result tables.
(hd_mean, dice_mean, hd_3D_mean, dice_3D_mean,
 hd_std, dice_std, hd_3D_std, dice_3D_std) = (
    np.array(values)
    for values in (hd_mean_2D, dice_mean_2D, hd_mean_3D, dice_mean_3D,
                   hd_std_2D, dice_std_2D, hd_std_3D, dice_std_3D)
)

In [None]:
# Column name -> values for the 2D and 3D result tables (one row per size).
data2D = {
    'sizes_2D': sizes_2D,
    'hd_mean_2D_purePython_GPU': hd_mean,
    'dice_mean_2D_purePython_GPU': dice_mean,
    'hd_std_2D_purePython_GPU': hd_std,
    'dice_std_2D_purePython_GPU': dice_std,
}
data3D = {
    'sizes_3D': sizes_3D,
    'hd_mean_3D_purePython_GPU': hd_3D_mean,
    'dice_mean_3D_purePython_GPU': dice_3D_mean,
    'hd_std_3D_purePython_GPU': hd_3D_std,
    'dice_std_3D_purePython_GPU': dice_3D_std,
}

In [None]:
# Assemble the result tables. pandas raises if the columns differ in length.
dataframe2D, dataframe3D = (pd.DataFrame(columns) for columns in (data2D, data3D))

In [None]:
# Directory for the result CSVs. Set the BENCHMARK_OUTPUT_DIR environment
# variable to override it instead of editing a hard-coded user path.
OUTPUT_DIR = Path(os.environ.get(
    "BENCHMARK_OUTPUT_DIR",
    "C:/Users/wenbl13/Desktop/Ashwin-Timing/distance-transforms",
))
dataframe2D.to_csv(OUTPUT_DIR / "purePython_Loss_2D_nov9_GPU.csv")
dataframe3D.to_csv(OUTPUT_DIR / "purePython_Loss_3D_nov9_GPU.csv")

In [None]:
# Display the 2D results table (one row per array size).
dataframe2D

Unnamed: 0,sizes_2D,hd_mean_2D_purePython,dice_mean_2D_purePython,hd_std_2D_purePython,dice_std_2D_purePython
0,3,44196.5,65002.8,219130.7,246537.09375
1,9,35026.5,74471.3,183943.6,249019.046875
2,23,39053.6,76000.7,193897.8,265001.5625
3,29,42959.9,64001.2,202715.1,244758.734375
4,43,47918.4,62047.0,213446.2,237418.171875
5,49,58895.4,61089.0,235252.0,239535.453125
6,63,69987.0,61992.8,255112.8,241202.296875
7,69,71929.91,64986.0,262155.8,246509.75
8,83,90999.39,67024.8,287639.8,250190.0625
9,89,96051.01,73993.3,294797.3,261825.0


In [None]:
# NOTE(review): these are 10 points (1, 101, ..., 901), which does not match
# num_range (51 sizes) used by the benchmarks. Confirm before plotting against x.
x = list(range(1, 1000, 100))

In [None]:
import matplotlib.pyplot as plt


In [None]:
# NOTE(review): this plotting cell is disabled. Before re-enabling it:
# - `dataframe` is not defined in this notebook; the tables are dataframe2D and dataframe3D.
# - The saved column names end in `_GPU`.
# - `x` has 10 points, while num_range produces 51 sizes.
# - The timings are recorded in nanoseconds, not seconds.
# plt.figure(figsize=(13, 13))
# plt.plot(x, dataframe['hd_mean_2D_purePython'], label = 'hd_mean_2D')
# plt.plot(x, dataframe['dice_mean_2D_purePython'], label = 'dice_mean_2D')
# plt.plot(x, dataframe['hd_mean_3D_purePython'], label = 'hd_mean_3D')
# plt.plot(x, dataframe['dice_mean_3D_purePython'], label = 'dice_mean_3D')
# plt.xlabel('Array_Size')
# plt.ylabel('Time (seconds)')
# plt.legend()
# plt.show()