In [None]:
import torch
import torch.nn as nn

In [None]:
class Network(nn.Module):
    """Small MLP mapping a 2-D point to a single scalar in [-1, 1].

    Architecture: Linear(2, 10) -> ReLU -> Linear(10, 10) -> tanh
    -> Linear(10, 1) -> tanh.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        """Map points of shape (..., 2) to scalars of shape (..., 1)."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.tanh(self.fc2(hidden))
        return torch.tanh(self.fc3(hidden))

In [None]:
# Fix the RNG so the weight initialisation below, and the random inputs /
# coefficients drawn by later cells, are reproducible on Restart & Run All.
torch.manual_seed(0)
model = Network()

In [None]:
# Random batch of 2-D query points. requires_grad=True is needed because the
# cells below differentiate the network output w.r.t. these coordinates.
n_points = 10
x = torch.randn([n_points, 2], requires_grad=True)

y = model(x)
assert y.shape == (n_points, 1), "expected one scalar output per point"
print(y.shape)

In [None]:
def make_second_order_deriv(y, x):
    """Build the 2-D divergence-free kernel from the Hessian of a scalar field.

    Given a scalar potential ``y`` evaluated at points ``x``, returns for each
    sample the 2x2 matrix

        K = [[-y_yy,  y_yx],
             [ y_xy, -y_xx]]

    where subscripts denote partial derivatives w.r.t. the two input
    coordinates (column 0 of ``x`` = "x", column 1 = "y"). For a coefficient
    vector ``c`` that does not depend on ``x``, the field ``K @ c`` has zero
    divergence, because mixed partials commute.

    Args:
        y: output holding exactly one scalar per sample, shape (B,) or (B, 1);
            must be a differentiable function of ``x``.
        x: input points of shape (B, 2) with ``requires_grad=True``.

    Returns:
        Tensor of shape (B, 2, 2). It is built with ``create_graph=True``, so it
        can itself be differentiated w.r.t. ``x``.

    Note:
        Gradients are taken of batch sums (``ones_like`` grad_outputs), so this
        assumes ``y[i]`` depends only on ``x[i]``, as for a per-sample MLP.
    """
    # Hard-coded for 2 input variables: the kernel layout below assumes (x, y).
    assert x.shape[1] == 2
    # One scalar potential per sample; a multi-column y would otherwise be
    # silently summed over its columns by the ones_like grad_outputs below.
    assert y.numel() == x.shape[0], "y must hold exactly one scalar per sample"

    # First-order derivatives: columns are dy/dx and dy/dy.
    grad_y = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0]
    y_x, y_y = grad_y.split(1, -1)

    # Second-order derivatives (the rows of the Hessian).
    y_xx, y_xy = torch.autograd.grad(y_x, x, torch.ones_like(y_x), create_graph=True)[0].split(1, -1)
    y_yx, y_yy = torch.autograd.grad(y_y, x, torch.ones_like(y_y), create_graph=True)[0].split(1, -1)

    # Assemble the B x 2 x 2 divergence-free kernel column by column.
    col0 = torch.cat([-y_yy, y_xy], dim=-1)[..., None]
    col1 = torch.cat([y_yx, -y_xx], dim=-1)[..., None]
    return torch.cat([col0, col1], dim=-1)

# Divergence-free kernel at every sample, applied to one random coefficient vector.
uv = make_second_order_deriv(y, x)
coeff = torch.randn(2, 1)
# (B, 2, 2) @ (2, 1) -> (B, 2, 1) -> (B, 2)
pred = (uv @ coeff).squeeze(-1)

print(uv.shape, coeff.shape, pred.shape, sep="\n")

In [None]:
# Check that the predicted field (u, v) is divergence-free:
# div = du/dx + dv/dy should vanish up to floating-point error.
u, v = torch.split(pred, 1, -1)
# No higher-order derivatives are needed here; retain_graph keeps pred's graph
# alive for the second grad call and for re-runs of this cell.
du_xy = torch.autograd.grad(u, x, torch.ones_like(u), retain_graph=True)[0]
dv_xy = torch.autograd.grad(v, x, torch.ones_like(v), retain_graph=True)[0]
# Gradient columns: index 0 is d/dx, index 1 is d/dy.
div_u_xy = du_xy[..., 0] + dv_xy[..., 1]
print(div_u_xy)
assert torch.allclose(du_xy[..., 0], -dv_xy[..., 1], rtol=1e-4, atol=1e-5), "field is not divergence-free"

torch.Size([10, 1])


In [None]:
class DivFreeNet(nn.Module):
    """Parameter-free module mapping a scalar potential to a divergence-free 2-D field.

    ``forward(pred, coeff, inputs)`` builds, per sample, the 2x2 kernel

        K = [[-u_yy,  u_yx],
             [ u_xy, -u_xx]]

    from the second derivatives of the potential ``pred`` w.r.t. ``inputs``
    and returns ``K @ coeff`` for each sample.
    """

    def __init__(self):
        super().__init__()

    def divfree_2D_output(self, y, c, x):
        """Return ``K[b] @ c[b]`` per sample, with K the divergence-free kernel of ``y``.

        Args:
            y: scalar potential, shape (B, 1), a differentiable function of ``x``.
            c: per-sample linear-combination coefficients, shape (B, 2).
            x: input points, shape (B, 2), with ``requires_grad=True``.

        Returns:
            Tensor of shape (B, 2), differentiable w.r.t. ``x`` (the graph is
            kept via ``create_graph=True``).

        Note:
            The output is divergence-free w.r.t. ``x`` only when ``c`` does not
            depend on ``x``; derivatives of ``c`` are not accounted for.
            Gradients are taken of batch sums, so this assumes ``y[i]`` depends
            only on ``x[i]``.
        """
        assert y.shape[0] == x.shape[0] == c.shape[0]  # same batch size
        assert y.shape[1] == 1  # one scalar potential per sample
        assert c.shape[1] == 2  # two coefficients for the linear combination
        assert x.shape[1] == 2  # (x, y) coordinates

        u = y

        # First-order derivatives; columns are du/dx and du/dy. create_graph
        # keeps the graph so they can be differentiated again below.
        du = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
        du_x, du_y = du.split(1, -1)

        # Second-order derivatives (the rows of the Hessian of u).
        du_x_grad = torch.autograd.grad(du_x, x, torch.ones_like(du_x), create_graph=True)[0]
        du_xx, du_xy = du_x_grad.split(1, -1)
        du_y_grad = torch.autograd.grad(du_y, x, torch.ones_like(du_y), create_graph=True)[0]
        du_yx, du_yy = du_y_grad.split(1, -1)

        # Assemble the B x 2 x 2 divergence-free kernel column by column.
        K1 = torch.cat([-du_yy, du_xy], dim=-1)[..., None]
        K2 = torch.cat([du_yx, -du_xx], dim=-1)[..., None]
        K = torch.cat([K1, K2], dim=-1)

        # Batched matrix-vector product pred[b] = K[b] @ c[b]: the kernel's
        # column index is contracted with the coefficient index.
        pred = torch.einsum('bxy,by->bx', K, c)

        return pred

    def forward(self, pred, coeff, inputs):
        """Apply :meth:`divfree_2D_output` to (potential, coefficients, points)."""
        return self.divfree_2D_output(pred, coeff, inputs)