In [1]:
import torch
import numpy as np
import phi.torch.flow as ptf

# Seed torch's RNG (CPU and all CUDA devices) so the randomly initialised
# learnable directions are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU when one is present, but fall back to the CPU so the notebook
# still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [36]:
# Simulation domain: the unit cube, sampled on a RESOLUTION^3 cell-centred grid.
BOUNDS = ptf.Box(x=1, y=1, z=1)
RESOLUTION = 100
# Radius of each spherical force blob; the weight (RADIUS - distance) is only
# kept inside this sphere.
RADIUS = 0.2


def _make_grid(values):
    """Build a CenteredGrid on BOUNDS at RESOLUTION^3 with zero-gradient extrapolation.

    `values` is passed straight to `ptf.CenteredGrid`: a constant, a phi tensor,
    or a function of the cell-centre positions.
    """
    return ptf.CenteredGrid(
        values=values,
        extrapolation=ptf.extrapolation.ZERO_GRADIENT,
        bounds=BOUNDS,
        resolution=ptf.spatial(x=RESOLUTION, y=RESOLUTION, z=RESOLUTION),
    )


def _linear_falloff(center, radius=RADIUS):
    """Return `loc -> radius - |loc - center|`, with `center` bound at call time.

    Using a factory instead of a lambda written inside the loop avoids the
    late-binding pitfall of closing over the loop variable.
    """
    return lambda loc: radius - ptf.length(loc - center)


def _inside_sphere(center, radius=RADIUS):
    """Return `loc -> |loc - center| < radius`, with `center` bound at call time."""
    return lambda loc: ptf.length(loc - center) < radius


# Blob centres: two opposite corners of the domain and its centre.
LOCATIONS = [
    ptf.vec(x=0, y=0, z=0),
    ptf.vec(x=0.5, y=0.5, z=0.5),
    ptf.vec(x=1, y=1, z=1),
]
# One learnable 3-vector per blob; these leaf tensors are what we differentiate w.r.t.
TO_LEARN_DIRECTION = [
    torch.randn(3, device=device, requires_grad=True) for _ in LOCATIONS
]
# The same tensors wrapped as phi vectors with named x, y, z components;
# gradients flow back to TO_LEARN_DIRECTION through them.
DIRECTIONS = [ptf.wrap(d, ptf.channel(vector='x,y,z')) for d in TO_LEARN_DIRECTION]

# Accumulate each blob's contribution: weight (RADIUS - distance) inside the
# sphere, zero outside (via MARK), times that blob's spatially constant direction.
FORCE_FIELD = _make_grid((0, 0, 0))
for LOCATION, DIRECTION in zip(LOCATIONS, DIRECTIONS):
    LAMBDA = _make_grid(_linear_falloff(LOCATION))
    MARK = _make_grid(_inside_sphere(LOCATION))
    VECTOR_FIELD = _make_grid(DIRECTION)
    FORCE_FIELD = FORCE_FIELD + LAMBDA * MARK * VECTOR_FIELD

loss = FORCE_FIELD.values.native('x,y,z,vector').sum()
# retain_graph=True keeps the autograd graph alive: the gradient check below
# differentiates `loss` again with torch.autograd.grad, which fails after a
# plain backward() has freed the graph's saved tensors.
loss.backward(retain_graph=True)

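## Check gradients

The loss is linear in each direction. Every component of a direction's gradient therefore equals the sum of `LAMBDA * MARK` over the grid cells inside that blob's sphere. The blob at (0.5, 0.5, 0.5) lies entirely inside the unit box. The blobs at the corners (0, 0, 0) and (1, 1, 1) keep only one octant of their sphere, so the centre gradient should be about 8× each corner gradient.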

In [37]:
# Gradient of the loss w.r.t. each learnable direction.
# allow_unused=True makes autograd return None (instead of raising) for an input
# that never reached the loss; for a gradient *check* that must not pass silently,
# so each entry is asserted below with a message naming the offending direction.
grads = torch.autograd.grad(
    outputs=loss,
    inputs=TO_LEARN_DIRECTION,
    retain_graph=True,  # keep the graph so this cell can be re-run
    allow_unused=True,
)
for i, grad in enumerate(grads):
    assert grad is not None, f"TO_LEARN_DIRECTION[{i}] does not influence the loss"
    print(f"Gradient for TO_LEARN_DIRECTION[{i}]: {grad}")

Gradient for TO_LEARN_DIRECTION[0]: tensor([209.4329, 209.4329, 209.4329], device='cuda:0')
Gradient for TO_LEARN_DIRECTION[1]: tensor([1675.4626, 1675.4626, 1675.4626], device='cuda:0')
Gradient for TO_LEARN_DIRECTION[2]: tensor([209.4330, 209.4330, 209.4330], device='cuda:0')
