In [None]:
import torch


def rdpz_rolled_components(x):
    """Return the two axial (dim-0) neighbours of every voxel of ``x``.

    Args:
        x: Image tensor; neighbours are taken along dimension 0, which must
            have at least 2 entries (the boundary slices are reflected).

    Returns:
        Tensor of shape ``(2, *x.shape)`` with the same dtype/device as ``x``.
        Index 0 holds the previous slice ``x[z - 1]`` and index 1 the next
        slice ``x[z + 1]``. At the boundaries the neighbour is reflected, the
        same convention as ``torch.nn.functional.pad(..., mode='reflect')``:
        the previous neighbour of slice 0 is ``x[1]`` and the next neighbour
        of the last slice is ``x[-2]``.
    """
    prev_neighbour = torch.roll(x, shifts=1, dims=0)   # prev[z] == x[z - 1]
    next_neighbour = torch.roll(x, shifts=-1, dims=0)  # next[z] == x[z + 1]
    # torch.roll wraps around; replace only the wrapped-in edge slice with its
    # reflected counterpart (indexing a single slice, not the whole volume).
    prev_neighbour[0] = x[1]
    next_neighbour[-1] = x[-2]
    return torch.stack((prev_neighbour, next_neighbour))

def get_rdpz_grad(x, gamma=2.0, eps=1e-9):
    """Gradient of the negated axial Relative Difference Prior terms w.r.t. ``x``.

    For each voxel value ``x`` and each of its two dim-0 neighbours ``n``
    (from ``rdpz_rolled_components``), the contribution is

        -(x - n) * (x + 3*n + gamma*|x - n|) / ((x + n + gamma*|x - n|)**2 + eps)

    which is the derivative of ``-(x - n)**2 / (x + n + gamma*|x - n|)`` with
    respect to ``x`` (plus a stabilising ``eps``). The two neighbour
    contributions are summed.

    Args:
        x: Image tensor; neighbours are taken along dim 0.
        gamma: Edge-preservation parameter. The default of 2 reproduces the
            previously hard-coded coefficients.
        eps: Added to the squared denominator to avoid division by zero.

    Returns:
        Tensor with the same shape as ``x``.
    """
    neighbours = rdpz_rolled_components(x)  # (2, *x.shape)
    centre = x[None]  # broadcasts against both neighbour volumes, no copy needed
    diff = centre - neighbours
    abs_diff = torch.abs(diff)
    numerator = diff * (gamma * abs_diff + centre + 3 * neighbours)
    denominator = (centre + neighbours + gamma * abs_diff) ** 2 + eps
    return -(numerator / denominator).sum(dim=0)

def get_rdp_value(x, gamma=2.0, eps=1e-9):
    """Negated Relative Difference Prior summed over a 3x3x3 neighbourhood.

    For every voxel and each of the 27 positions of its 3x3x3 neighbourhood
    (the voxel itself included, which contributes 0), this accumulates

        -(x - n)**2 / max(x + n + gamma*|x - n|, eps)

    and returns the sum over all voxels. Borders are handled by reflect
    padding one voxel in each of the three dimensions.

    Args:
        x: 3-D image tensor; each dimension needs at least 2 entries so that
            reflect padding is valid.
        gamma: Edge-preservation parameter (default 2, the previously
            hard-coded value).
        eps: Lower clamp on the denominator to avoid division by zero.

    Returns:
        0-d tensor holding the summed (negated) prior value.
    """
    rdp = torch.zeros_like(x)
    # Reflect-pad by one voxel in all three dims (not only the axial one).
    padded = torch.nn.functional.pad(x[None], (1, 1, 1, 1, 1, 1), mode='reflect')[0]
    depth, height, width = x.shape
    for i in range(3):
        for j in range(3):
            for k in range(3):
                # Neighbour at offset (i-1, j-1, k-1) for every voxel.
                neighbour = padded[i:i + depth, j:j + height, k:k + width]
                numerator = (x - neighbour) ** 2
                denominator = x + neighbour + gamma * torch.abs(x - neighbour)
                rdp -= numerator / torch.clamp(denominator, eps)
    return rdp.sum()

# Smoke test: a bright 3x3x3 cube inside an otherwise zero 5x5x5 volume.
# Nothing in this cell is random; the seed is kept for any later stochastic cells.
torch.random.manual_seed(1337)
# Fall back to CPU so this cell also runs on kernels without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.zeros(5, 5, 5, device=device)
x[1:4, 1:4, 1:4] = 10000.
rdp_val = get_rdp_value(x)
rdp_grad = get_rdpz_grad(x)
print(rdp_val)
print(x[3])
print(rdp_grad[3])  # inspect the gradient on the same slice as the input above

In [None]:
import torch
import rdp

# Evaluate the same volume with the compiled `rdp` package; both inputs are
# moved to the GPU first, so this cell needs CUDA.
# NOTE(review): the all-ones 3x3x3 tensor presumably acts as per-neighbour
# weights for a 3x3x3 neighbourhood — confirm against the `rdp` API.
z = rdp.compute_value(
    x.cuda(),
    torch.ones((3, 3, 3)).cuda(),
)
print(z)