In [3]:
%matplotlib inline
import numpy as np
import torch
from torch import nn

In [54]:
class NNSlaterDet(nn.Module):
    """Antisymmetric model: determinant of a learned per-configuration feature matrix.

    A shared linear layer maps every configuration vector (length
    ``config_size``) independently to ``num_basis`` features. For each sample,
    the features of its ``n_configs`` configurations are stacked into an
    ``n_configs x num_basis`` matrix, and the determinant of that matrix is
    returned. Swapping two configurations swaps two rows of the matrix, so the
    output changes sign. The model is therefore antisymmetric under permutation
    of the configurations, not invariant.

    Args:
        n_configs: number of configurations per sample (size of the
            second-to-last input axis).
        config_size: length of each configuration vector (size of the last
            input axis).
        num_basis: number of features produced per configuration. Must equal
            ``n_configs`` so that the stacked matrix is square.

    Raises:
        ValueError: if ``num_basis != n_configs``.
    """

    def __init__(self, n_configs, config_size, num_basis):
        super().__init__()
        if num_basis != n_configs:
            # torch.det needs square matrices; fail here rather than on the
            # first forward pass.
            raise ValueError(
                "num_basis ({}) must equal n_configs ({}) for the determinant "
                "to be defined".format(num_basis, n_configs))
        self.num_basis = num_basis
        self.config_size = config_size
        self.n_configs = n_configs
        self.network = nn.Linear(self.config_size, self.num_basis)

    def forward(self, x):
        """Return one determinant per sample in ``x``.

        Args:
            x: tensor of shape ``(*batch, n_configs, config_size)``. Any number
                of leading batch dimensions (including none) is accepted.

        Returns:
            Tensor of shape ``batch`` with the determinant of each sample's
            ``n_configs x num_basis`` feature matrix.
        """
        x_shape = list(x.shape)
        assert x_shape[-1] == self.config_size
        assert x_shape[-2] == self.n_configs

        # Apply the shared layer to every configuration as one flat batch.
        # Use reshape (not view) so non-contiguous inputs, e.g. produced by
        # permute/transpose, are accepted.
        features = self.network(x.reshape(-1, self.config_size))
        matrices = features.reshape(x_shape[:-1] + [self.num_basis])

        return torch.det(matrices)


class NNVandermonde(nn.Module):
    """Antisymmetric model: Vandermonde product of learned per-configuration scalars.

    A shared linear layer maps every configuration vector (length
    ``config_size``) to a single scalar ``v``. For a sample with scalars
    ``v_0 .. v_{n-1}`` (``n = n_configs``), the output is
    ``prod_{i<j} (v_i - v_j)``. This is the Vandermonde determinant up to a
    sign that depends only on ``n``. Swapping two configurations flips the
    sign of the output.

    Args:
        n_configs: number of configurations per sample (size of the
            second-to-last input axis).
        config_size: length of each configuration vector (size of the last
            input axis).
    """

    def __init__(self, n_configs, config_size):
        super().__init__()
        self.config_size = config_size
        self.n_configs = n_configs
        # Wrapped in Sequential: parameters are named network.0.weight /
        # network.0.bias, so keep this structure for state_dict compatibility.
        self.network = nn.Sequential(nn.Linear(self.config_size, 1))

    def forward(self, x):
        """Return ``prod_{i<j} (v_i - v_j)`` for every sample in ``x``.

        Args:
            x: tensor of shape ``(*batch, n_configs, config_size)``.

        Returns:
            Tensor of shape ``batch``. With ``n_configs <= 1`` there are no
            pairs, and the result is a tensor of ones (the empty product).
        """
        x_shape = list(x.shape)
        assert x_shape[-1] == self.config_size
        assert x_shape[-2] == self.n_configs

        # One scalar per configuration. Use reshape (not view) so
        # non-contiguous inputs are accepted.
        values = self.network(x.reshape(-1, self.config_size)).reshape(x_shape[:-1])

        # diffs[..., i, j] = v_i - v_j. Keep the strict upper triangle (i < j)
        # and multiply those pairwise differences together.
        n = values.shape[-1]
        diffs = values.unsqueeze(-1) - values.unsqueeze(-2)
        rows, cols = torch.triu_indices(n, n, offset=1, device=values.device)
        return diffs[..., rows, cols].prod(dim=-1)

In [20]:
torch.manual_seed(0)  # fix the randomly initialised layer weights so outputs are reproducible
ff = NNSlaterDet(n_configs=3, config_size=2, num_basis=3)

In [21]:
# Seed so the random sample batch is identical on every Restart & Run All.
torch.manual_seed(0)
# 10 samples, each with 3 configurations of size 2 (matching the models above).
xx = torch.rand(10, 3, 2)
# Same batch with configurations 1 and 2 swapped (configuration 0 unchanged),
# used to check how the models respond to a permutation.
permuted_xx = xx[:, [0, 2, 1], :]

In [22]:
# Determinants for the original batch; compare with ff(permuted_xx) below.
ff(xx)


tensor([ 1.6894e-03,  7.3975e-04, -2.5405e-03, -5.9339e-04, -1.3896e-03,
         7.9457e-05,  1.0185e-03,  3.8652e-04,  2.0074e-04, -2.5871e-04],
       grad_fn=<DetBackward>)

In [23]:
# Swapping two configurations swaps two rows of each matrix, so every
# determinant should flip sign relative to ff(xx).
ff(permuted_xx)

tensor([-1.6894e-03, -7.3975e-04,  2.5405e-03,  5.9339e-04,  1.3896e-03,
        -7.9457e-05, -1.0185e-03, -3.8652e-04, -2.0074e-04,  2.5871e-04],
       grad_fn=<DetBackward>)

In [55]:
torch.manual_seed(0)  # fix the randomly initialised layer weights so outputs are reproducible
ff2 = NNVandermonde(n_configs=3, config_size=2)

In [56]:
# Vandermonde products for the original batch; compare with ff2(permuted_xx) below.
ff2(xx)

tensor([-4.4096e-03,  1.0432e-03,  5.8990e-03,  6.9865e-04, -1.0369e-02,
         7.2225e-05, -1.1288e-03, -3.3655e-04, -2.4759e-05, -2.5304e-03],
       grad_fn=<MulBackward0>)

In [57]:
# A single swap of two configurations flips the sign of the product
# prod_{i<j} (v_i - v_j), so these values should be the negatives of ff2(xx).
ff2(permuted_xx)

tensor([ 4.4096e-03, -1.0432e-03, -5.8990e-03, -6.9865e-04,  1.0369e-02,
        -7.2225e-05,  1.1288e-03,  3.3655e-04,  2.4759e-05,  2.5304e-03],
       grad_fn=<MulBackward0>)