In [1]:
import torch

In [2]:
# Toy volume of d=2 slices, each h x w = 5 x 5, laid out as (N, C, D, H, W) for conv3d.
d, h, w = 2, 5, 5
x = torch.arange(d * h * w).view(1, 1, d, h, w)
# Bare last expression -> Jupyter's rich display instead of print().
x


tensor([[[[[ 0,  1,  2,  3,  4],
           [ 5,  6,  7,  8,  9],
           [10, 11, 12, 13, 14],
           [15, 16, 17, 18, 19],
           [20, 21, 22, 23, 24]],

          [[25, 26, 27, 28, 29],
           [30, 31, 32, 33, 34],
           [35, 36, 37, 38, 39],
           [40, 41, 42, 43, 44],
           [45, 46, 47, 48, 49]]]]])


In [3]:
# One power of two per cube corner: corner (dz, dy, dx) carries 2**(4*dz + 2*dy + dx),
# so convolving a volume yields one cube code per 2x2x2 window.
weight = torch.pow(2, torch.arange(8)).reshape(1, 1, 2, 2, 2).to(torch.long)
torch.nn.functional.conv3d(x, weight, padding=(0, 1, 1)).view(-1)

tensor([ 3200,  4936,  5140,  5344,  5548,  1872,  4680,  7190,  7445,  7700,
         7955,  2680,  5530,  8465,  8720,  8975,  9230,  3105,  6380,  9740,
         9995, 10250, 10505,  3530,  7230, 11015, 11270, 11525, 11780,  3955,
         1480,  2254,  2305,  2356,  2407,   808])

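### Using Unfold

The same cube codes can be computed with `torch.nn.Unfold` by treating the depth axis as channels. With `d = 2`, each unfolded column then holds the 8 corners of one 2×2×2 cube.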

In [4]:
# Pick the integer dtype used to accumulate the cube codes; torch >= 2 uses int32.
# NOTE(review): the reason older versions need int64 is not shown here — confirm.
torch_ver_major = int(torch.__version__.partition('.')[0])
if torch_ver_major >= 2:
    dtype_index = torch.int32
else:
    dtype_index = torch.long

# 2x2 in-plane windows with one voxel of zero padding, mirroring conv3d's padding=(0, 1, 1).
unfold = torch.nn.Unfold(kernel_size=(2, 2), padding=1)
x = torch.arange(d * h * w).reshape(d, h, w)
x

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19],
         [20, 21, 22, 23, 24]],

        [[25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39],
         [40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49]]])

In [5]:
# Use a new name instead of rebinding `x`: re-running this cell would otherwise
# unsqueeze `x` a second time and hand Unfold an unsupported 5-D tensor.
x_half = x.to(torch.float16).unsqueeze(0)

# Unfold treats the d slices as channels, so each column holds d * 4 values ordered
# (slice, row, col) inside a 2x2 window. Only the first 8 rows are used below,
# which assumes d == 2.
cubes_float = unfold(x_half).squeeze(0)
cubes_byte = torch.zeros(cubes_float.size(1), dtype=dtype_index)

# Row k contributes bit k, the same weighting as the conv3d kernel.
for k in range(8):
    cubes_byte += cubes_float[k, :].to(dtype_index) << k

cubes_byte

tensor([ 3200,  4936,  5140,  5344,  5548,  1872,  4680,  7190,  7445,  7700,
         7955,  2680,  5530,  8465,  8720,  8975,  9230,  3105,  6380,  9740,
         9995, 10250, 10505,  3530,  7230, 11015, 11270, 11525, 11780,  3955,
         1480,  2254,  2305,  2356,  2407,   808], dtype=torch.int32)

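### Surface Dice using Conv3D

The conv3d cube codes are used as indices into a per-code area table (`self.area`), and the resulting areas are accumulated batch by batch.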

In [None]:
def compute_surface_area(self, surface):
    """Look up ``self.area`` for every 2x2x2 cube of adjacent voxels in ``surface``.

    Each cube is encoded as an 8-bit code in which corner ``(dz, dy, dx)`` sets bit
    ``4*dz + 2*dy + dx``. The codes come from a single conv3d whose kernel holds
    the powers of two. They are then used as indices into ``self.area``
    (presumably a 256-entry table of per-configuration areas — confirm).

    Args:
        surface: mask of shape ``(d, h, w)``. Assumed binary (0/1); other values
            would produce codes outside 0..255 (TODO confirm with callers).

    Returns:
        1-D tensor of length ``(d - 1) * (h + 1) * (w + 1)`` holding one
        ``self.area`` entry per cube.
    """
    d, h, w = surface.shape

    # conv3d requires input and weight to share dtype and device, so build the kernel
    # on the input's device. float32 is used because integer convolutions are generally
    # not available on CUDA backends, and float32 represents every code 0..255 exactly.
    weight = (2 ** torch.arange(8, device=surface.device)).view(1, 1, 2, 2, 2).float()
    volume = surface.reshape(1, 1, d, h, w).float()

    # No padding along depth (cubes span adjacent slices only); one voxel of zero padding
    # in H/W so cubes touching the in-plane border are included.
    codes = torch.nn.functional.conv3d(volume, weight, padding=(0, 1, 1)).flatten()
    cubes_byte = codes.round().long()

    cubes_area = self.area[cubes_byte]

    return cubes_area

def process_batch(self, pred, target):
    """Accumulate the Surface Dice numerator/denominator for one batch of slices.

    The volume arrives as consecutive batches of 2-D slices. ``compute_surface_area``
    scores pairs of adjacent slices, so each batch is extended along the slice axis
    before scoring:

    * first batch: zero slice(s) are prepended;
    * every later batch: the last slice of the previous batch is prepended, so the
      slice pair straddling the batch boundary is scored;
    * last batch: zero slice(s) are appended.

    A batch that is both first and last (``self.n_batches == 1``) gets both paddings.

    Args:
        pred: predicted mask batch of shape ``(bs, h, w)`` (assumed 0/1 — confirm).
        target: ground-truth mask batch with the same shape as ``pred``.

    Side effects:
        Adds to ``self.numerator`` and ``self.denominator``, increments
        ``self.batch_idx``, and stores ``self.pred_pad`` / ``self.target_pad``
        for the next batch.
    """
    bs, h, w = pred.shape
    # NOTE(review): even batch sizes get two zero slices instead of one. The extra
    # empty slice pair only adds all-empty cubes (presumably zero area) — confirm intent.
    pad_d = int(bs % 2 == 0) + 1
    pad = torch.zeros((pad_d, h, w), dtype=torch.uint8, device=self.device)

    is_first = self.batch_idx == 0
    is_last = self.batch_idx == self.n_batches - 1

    # Leading context: zero border for the first batch, otherwise the previous batch's
    # last slice. The last batch must also get this, or its boundary pair is lost.
    if is_first:
        pred = torch.vstack([pad, pred])
        target = torch.vstack([pad, target])
    else:
        pred = torch.vstack([self.pred_pad, pred])
        target = torch.vstack([self.target_pad, target])

    # Trailing zero border closes the volume; independent of the prepend decision so a
    # single-batch volume is padded on both ends.
    if is_last:
        pred = torch.vstack([pred, pad])
        target = torch.vstack([target, pad])

    area_pred = self.compute_surface_area(pred)
    area_true = self.compute_surface_area(target)

    # Cubes that carry area in both the prediction and the ground truth.
    idx = torch.logical_and(area_pred > 0, area_true > 0)

    self.numerator += area_pred[idx].sum() + area_true[idx].sum()
    self.denominator += area_pred.sum() + area_true.sum()

    self.batch_idx += 1
    # Keep the final slice so the next batch can pair it with its first slice.
    self.pred_pad = pred[-1:]
    self.target_pad = target[-1:]