## Import packages

In [28]:
import numpy
import torch
import timeit
from scipy.ndimage import distance_transform_edt as distance
from scipy.ndimage import _nd_image

## Generate test data

In [29]:
# Fixed seed so the synthetic binary volume (and every result derived from it)
# is identical on each Restart & Run All.
SEED = 0
x = numpy.random.RandomState(SEED).randint(2, size=(112, 112, 96))

## Pure feature transform

In [30]:
def feature_transform(input):
    """Run SciPy's Euclidean feature transform on a binarised copy of ``input``.

    Parameters
    ----------
    input : array_like
        Truthy elements are mapped to 1 and falsy elements to 0 before the
        transform is computed.

    Returns
    -------
    numpy.ndarray
        ``int32`` array of shape ``(ndim,) + shape`` as written by
        ``_nd_image.euclidean_feature_transform``. Per SciPy's
        ``distance_transform_edt`` documentation this holds, for every
        element, the index of the closest background (zero) element.
    """
    # Binarise to an int8 mask that is at least 1-D.
    binary = numpy.where(input, 1, 0)
    binary = numpy.atleast_1d(binary.astype(numpy.int8))

    # Output buffer with one index plane per axis; the private SciPy routine
    # writes its result into it.
    indices_shape = (binary.ndim, *binary.shape)
    nearest = numpy.zeros(indices_shape, dtype=numpy.int32)
    _nd_image.euclidean_feature_transform(binary, None, nearest)
    return nearest

In [31]:
# Precompute the feature transform once so the distance step can be run and timed on its own.
x_ft = feature_transform(x)

In [32]:
%%timeit 
# Time the feature-transform step in isolation.
feature_transform(x)

1 loop, best of 5: 223 ms per loop


## Pure distance transform

In [33]:
def distance_transform(input, ft):
    """Convert a feature transform into Euclidean distances.

    Parameters
    ----------
    input : numpy.ndarray
        The array the feature transform was computed for; only its ``shape``
        is used, to build the coordinate grid.
    ft : numpy.ndarray
        Feature transform of shape ``(input.ndim,) + input.shape`` holding,
        per axis, the coordinate each element is measured against.

    Returns
    -------
    numpy.ndarray
        ``float64`` array of shape ``input.shape`` with the Euclidean distance
        from each element to the coordinate stored for it in ``ft``.
    """
    # Per-axis offset between each element's own index and its target index.
    offsets = ft - numpy.indices(input.shape, dtype=ft.dtype)
    offsets = offsets.astype(numpy.float64)

    # Square in place, sum over the axis dimension, then take the root.
    numpy.multiply(offsets, offsets, out=offsets)
    return numpy.sqrt(numpy.add.reduce(offsets, axis=0))

In [34]:
# Distance step only, reusing the precomputed feature transform x_ft.
x_dt = distance_transform(x, x_ft)

In [35]:
%%timeit 
# Time only the distance step; the feature transform x_ft is already computed.
distance_transform(x, x_ft)

10 loops, best of 5: 20.9 ms per loop


## Full distance transform

In [36]:
# Baseline: SciPy's distance_transform_edt (imported as `distance`) on the same volume.
x_dtm_pure = distance(x)

In [37]:
%%timeit
# Time the SciPy baseline (distance_transform_edt imported as `distance`).
distance(x)

1 loop, best of 5: 249 ms per loop


## Modify pure distance transform

### Define the combined function

In [None]:
def distance_transform_edt(input):
    """Euclidean distance transform: feature transform plus distance step.

    Self-contained combination of the two stages (feature transform via
    SciPy's private ``_nd_image`` routine, then distances computed with
    NumPy).

    Parameters
    ----------
    input : array_like
        Truthy elements are mapped to 1 and falsy elements to 0 before the
        transform is computed.

    Returns
    -------
    numpy.ndarray
        ``float64`` array with the shape of the (at least 1-D) input, holding
        the Euclidean distance from each element to the coordinate the
        feature transform reports for it.
    """
    # Stage 1: feature transform on an int8 binary mask.
    input = numpy.atleast_1d(numpy.where(input, 1, 0).astype(numpy.int8))
    ft = numpy.zeros((input.ndim,) + input.shape, dtype=numpy.int32)
    _nd_image.euclidean_feature_transform(input, None, ft)

    # Stage 2: per-axis offsets -> squared (in place) -> summed -> root.
    offsets = ft - numpy.indices(input.shape, dtype=ft.dtype)
    offsets = offsets.astype(numpy.float64)
    numpy.multiply(offsets, offsets, out=offsets)
    return numpy.sqrt(numpy.add.reduce(offsets, axis=0))

In [None]:
x_dtm_new = distance_transform_edt(x)
# Sanity check: the re-implementation must reproduce the SciPy baseline
# stored in x_dtm_pure.
numpy.testing.assert_allclose(x_dtm_new, x_dtm_pure)

In [None]:
%%time 
# Single timed run of the combined re-implementation.
x_dtm_new = distance_transform_edt(x)

CPU times: user 269 ms, sys: 2.49 ms, total: 272 ms
Wall time: 276 ms
