In [22]:
"""
This notebook walks through the MonomialLayer class of polynomial.py.
Updated 3/9/2023
"""

'\nThis notebook walks through the MonomialLayer class of polynomial.py.\nUpdated 3/9/2023\n'

In [23]:
import torch
from torch import nn
import math

In [24]:
# Current definition of MonomialLayer
class MonomialLayer(nn.Module):
    """Outputs all possible monomials up to a given degree from the inputs.

    The basic idea is to add the number 1 to the list of inputs and
    then create every possible monomial of the given degree from
    factors of the inputs and one.  The presence of one in the list
    generates the monomials with degree less than the given degree.
    See the math of multisets for more information.  (Wikipedia has a
    good entry on this.)

    The layer is device-agnostic: the index table is registered as a
    (non-persistent) buffer, so ``layer.to(device)`` / ``layer.cuda()``
    moves it together with the module, and ``forward`` follows the device
    of its input.

    Attributes:
       n_inputs: The number of inputs this layer expects
       degree: The maximum degree of the monomial
       n_outputs: The number of monomials output (constant term excluded)
       m_ind: int32 buffer of shape (n_outputs, degree).  Row r holds the
           positions, within the input vector with a 1 prepended, of the
           factors that are multiplied to form monomial r.  Position 0
           is the prepended 1.

    """
    def __init__(self, n_inputs: int, degree: int) -> None:
        """Constructor

        Args:
            n_inputs: The number of input variables to the layer
            degree: The maximum degree of the monomials

        Raises:
            ValueError: If n_inputs or degree is negative.
        """

        super().__init__()

        if n_inputs < 0 or degree < 0:
            raise ValueError(
                f'n_inputs and degree must be non-negative, '
                f'got n_inputs={n_inputs}, degree={degree}')

        self.n_inputs = n_inputs
        self.degree = degree
        # Multisets of size `degree` drawn from the n_inputs variables plus
        # the constant 1: C(n_inputs + degree, degree).  Exact integer
        # arithmetic is used (no float round-trip).
        # No point in keeping the constant term, hence the -1.
        self.n_outputs = math.factorial(n_inputs + degree) // (
            math.factorial(n_inputs) * math.factorial(degree)) - 1

        # Now, let's build an array of the indices of the inputs that
        # need to be combined for each monomial.  The table is built with
        # plain Python ints and converted to a tensor once, instead of
        # updating tensor elements one at a time.
        rows = []
        curr_ind = [0] * degree
        for _ in range(self.n_outputs):
            # Advance curr_ind to the next non-increasing index tuple.
            # Scan from the last column towards the first: the first
            # column that is still below its left neighbour is bumped;
            # every column passed on the way is reset to 0.
            for col in range(degree - 1, 0, -1):
                if curr_ind[col - 1] > curr_ind[col]:
                    curr_ind[col] += 1
                    break
                curr_ind[col] = 0
            else:
                # No column could be bumped: move on to the next leading
                # index and restart all trailing columns at 0.
                curr_ind[0] += 1
                curr_ind[1:] = [0] * (degree - 1)

            # Record a copy of the indices for this row
            rows.append(list(curr_ind))

        # reshape keeps the (n_outputs, degree) shape even when the table
        # is empty (degree == 0 or n_inputs == 0).
        m_ind = torch.tensor(rows, dtype=torch.int32).reshape(
            self.n_outputs, degree)
        # persistent=False: the table is fully determined by the
        # constructor arguments, so it stays out of state_dict().
        self.register_buffer('m_ind', m_ind, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Performs forward propagation for this module

        Args:
            x:
              The input tensor to this layer.  This function preserves
              all indices except the last one, which is assumed to
              index the input variables.  Leading indices can be used
              for minibatches or structuring the inputs.

        Returns:
            A tensor of shape ``x.shape[:-1] + (n_outputs,)`` on the same
            device as ``x``.  The prepended ones use the default floating
            dtype, so integer inputs come back as floating point.

        Raises:
            IndexError: If the last dimension of x is not n_inputs.
        """

        if x.shape[-1] != self.n_inputs:
            raise IndexError(f'Expecting {self.n_inputs} inputs, got {x.shape[-1]}')

        # No-op when the module already lives on x's device; otherwise the
        # table is copied over so the call still works.
        m_ind = self.m_ind.to(x.device)
        ones = torch.ones(x.shape[:-1] + (1,), device=x.device)
        x = torch.cat((ones, x), dim=-1)
        # Gather every factor of every monomial, regroup them as
        # (..., n_outputs, degree) and multiply along the last axis.
        factors = torch.index_select(x, -1, m_ind.flatten())
        return torch.prod(factors.reshape(x.shape[:-1] + m_ind.shape), dim=-1)


In [25]:
# Quick demonstration
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs end to end.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mono = MonomialLayer(16, 2).to(device)
x = torch.tensor([[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]], device=device)
# Call the module itself rather than .forward() so nn.Module hooks run.
mono(x)

tensor([[  1.,   1.,   2.,   2.,   4.,   3.,   3.,   6.,   9.,   4.,   4.,   8.,
          12.,  16.,   5.,   5.,  10.,  15.,  20.,  25.,   6.,   6.,  12.,  18.,
          24.,  30.,  36.,   7.,   7.,  14.,  21.,  28.,  35.,  42.,  49.,   8.,
           8.,  16.,  24.,  32.,  40.,  48.,  56.,  64.,   9.,   9.,  18.,  27.,
          36.,  45.,  54.,  63.,  72.,  81.,  10.,  10.,  20.,  30.,  40.,  50.,
          60.,  70.,  80.,  90., 100.,  11.,  11.,  22.,  33.,  44.,  55.,  66.,
          77.,  88.,  99., 110., 121.,  12.,  12.,  24.,  36.,  48.,  60.,  72.,
          84.,  96., 108., 120., 132., 144.,  13.,  13.,  26.,  39.,  52.,  65.,
          78.,  91., 104., 117., 130., 143., 156., 169.,  14.,  14.,  28.,  42.,
          56.,  70.,  84.,  98., 112., 126., 140., 154., 168., 182., 196.,  15.,
          15.,  30.,  45.,  60.,  75.,  90., 105., 120., 135., 150., 165., 180.,
         195., 210., 225.,  16.,  16.,  32.,  48.,  64.,  80.,  96., 112., 128.,
         144., 160., 176., 1

In [26]:
# Walk through little by little
n_inputs = 16
degree = 2
# The number of multisets of size `degree` drawn from the inputs plus the
# constant 1 is C(n_inputs + degree, degree).  Computed with exact integer
# arithmetic rather than float division.
# No point in keeping the constant term, hence the -1.
n_outputs = math.factorial(n_inputs + degree) // (
    math.factorial(n_inputs) * math.factorial(degree)) - 1

In [27]:
# Now, let's build an array of the indices of the inputs that
# need to be combined for each monomial.
# Both tensors are allocated on the device x lives on (instead of a
# hard-coded .cuda()), since m_ind is later used to index into x.
m_ind = torch.zeros(n_outputs, degree, dtype=torch.int32, device=x.device)
curr_ind = torch.zeros(degree, dtype=torch.int32, device=x.device)
print(m_ind)
print(curr_ind)

tensor([[0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        

In [28]:
# Fill m_ind one row per monomial.  curr_ind carries the previous row's
# indices and is advanced in place to the next index tuple on every pass;
# for degree 2 the rows come out as [1,0], [1,1], [2,0], [2,1], [2,2], ...
for row in range(n_outputs):
    # Calculate the values for this row
    # Scan from the last column towards column 1 (column 0 is handled by
    # the for-else branch below).
    for col in range(degree-1,0,-1):
        if curr_ind[col-1] > curr_ind[col]:
            # This column is still below its left neighbour: bump it and
            # stop, leaving the columns to its left untouched.
            curr_ind[col]+=1
            break
        else:
            # This column has caught up with its left neighbour: reset it
            # and carry on to the next column to the left.
            curr_ind[col]=0
    else:
        # Runs only when the inner loop finished without a break, i.e. no
        # column could be bumped: advance the leading index and restart
        # the trailing columns at 0.
        curr_ind[0]+=1
        curr_ind[1:]=0 # Broadcasts the scalar 0 over the whole slice

    # Set the indices for this row
    m_ind[row,:] =  curr_ind

In [29]:
# Inspect the finished index table: one row per monomial, one column per
# factor.  Index 0 will refer to the constant 1 that gets prepended to x.
m_ind

tensor([[ 1,  0],
        [ 1,  1],
        [ 2,  0],
        [ 2,  1],
        [ 2,  2],
        [ 3,  0],
        [ 3,  1],
        [ 3,  2],
        [ 3,  3],
        [ 4,  0],
        [ 4,  1],
        [ 4,  2],
        [ 4,  3],
        [ 4,  4],
        [ 5,  0],
        [ 5,  1],
        [ 5,  2],
        [ 5,  3],
        [ 5,  4],
        [ 5,  5],
        [ 6,  0],
        [ 6,  1],
        [ 6,  2],
        [ 6,  3],
        [ 6,  4],
        [ 6,  5],
        [ 6,  6],
        [ 7,  0],
        [ 7,  1],
        [ 7,  2],
        [ 7,  3],
        [ 7,  4],
        [ 7,  5],
        [ 7,  6],
        [ 7,  7],
        [ 8,  0],
        [ 8,  1],
        [ 8,  2],
        [ 8,  3],
        [ 8,  4],
        [ 8,  5],
        [ 8,  6],
        [ 8,  7],
        [ 8,  8],
        [ 9,  0],
        [ 9,  1],
        [ 9,  2],
        [ 9,  3],
        [ 9,  4],
        [ 9,  5],
        [ 9,  6],
        [ 9,  7],
        [ 9,  8],
        [ 9,  9],
        [10,  0],
        [1

In [30]:
# Every dimension of x except the last one (the last one indexes the inputs)
x.shape[:-1]

torch.Size([1])

In [37]:
# Shape of the block of ones to prepend: same leading dims, last axis of length 1
x.shape[:-1]+(1,) # the shape and the tuple (1,) are concatenated

torch.Size([1, 1])

In [39]:
# Prepend a column of ones so that index 0 in m_ind selects the constant 1.
# The ones are created on x's own device instead of a hard-coded .cuda().
# x is rebound here, so the guard keeps the cell idempotent: without it,
# every re-run would prepend one more column.
if x.shape[-1] == n_inputs:
    x = torch.cat((torch.ones(x.shape[:-1]+(1,), device=x.device), x), dim=-1)

In [40]:
# x after the concatenation: a leading 1 followed by the original inputs
x

tensor([[ 1.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15., 16.]], device='cuda:0')

In [41]:
# Flatten the (n_outputs, degree) index table row by row into the 1-D
# index vector that is handed to index_select in the next step
m_ind.flatten()

tensor([ 1,  0,  1,  1,  2,  0,  2,  1,  2,  2,  3,  0,  3,  1,  3,  2,  3,  3,
         4,  0,  4,  1,  4,  2,  4,  3,  4,  4,  5,  0,  5,  1,  5,  2,  5,  3,
         5,  4,  5,  5,  6,  0,  6,  1,  6,  2,  6,  3,  6,  4,  6,  5,  6,  6,
         7,  0,  7,  1,  7,  2,  7,  3,  7,  4,  7,  5,  7,  6,  7,  7,  8,  0,
         8,  1,  8,  2,  8,  3,  8,  4,  8,  5,  8,  6,  8,  7,  8,  8,  9,  0,
         9,  1,  9,  2,  9,  3,  9,  4,  9,  5,  9,  6,  9,  7,  9,  8,  9,  9,
        10,  0, 10,  1, 10,  2, 10,  3, 10,  4, 10,  5, 10,  6, 10,  7, 10,  8,
        10,  9, 10, 10, 11,  0, 11,  1, 11,  2, 11,  3, 11,  4, 11,  5, 11,  6,
        11,  7, 11,  8, 11,  9, 11, 10, 11, 11, 12,  0, 12,  1, 12,  2, 12,  3,
        12,  4, 12,  5, 12,  6, 12,  7, 12,  8, 12,  9, 12, 10, 12, 11, 12, 12,
        13,  0, 13,  1, 13,  2, 13,  3, 13,  4, 13,  5, 13,  6, 13,  7, 13,  8,
        13,  9, 13, 10, 13, 11, 13, 12, 13, 13, 14,  0, 14,  1, 14,  2, 14,  3,
        14,  4, 14,  5, 14,  6, 14,  7, 

In [42]:
# Gather, along the last axis of x, the entry named by each flattened index.
# Index 0 picks the prepended 1, index k picks the k-th original input.
torch.index_select(x,-1,m_ind.flatten())

tensor([[ 1.,  1.,  1.,  1.,  2.,  1.,  2.,  1.,  2.,  2.,  3.,  1.,  3.,  1.,
          3.,  2.,  3.,  3.,  4.,  1.,  4.,  1.,  4.,  2.,  4.,  3.,  4.,  4.,
          5.,  1.,  5.,  1.,  5.,  2.,  5.,  3.,  5.,  4.,  5.,  5.,  6.,  1.,
          6.,  1.,  6.,  2.,  6.,  3.,  6.,  4.,  6.,  5.,  6.,  6.,  7.,  1.,
          7.,  1.,  7.,  2.,  7.,  3.,  7.,  4.,  7.,  5.,  7.,  6.,  7.,  7.,
          8.,  1.,  8.,  1.,  8.,  2.,  8.,  3.,  8.,  4.,  8.,  5.,  8.,  6.,
          8.,  7.,  8.,  8.,  9.,  1.,  9.,  1.,  9.,  2.,  9.,  3.,  9.,  4.,
          9.,  5.,  9.,  6.,  9.,  7.,  9.,  8.,  9.,  9., 10.,  1., 10.,  1.,
         10.,  2., 10.,  3., 10.,  4., 10.,  5., 10.,  6., 10.,  7., 10.,  8.,
         10.,  9., 10., 10., 11.,  1., 11.,  1., 11.,  2., 11.,  3., 11.,  4.,
         11.,  5., 11.,  6., 11.,  7., 11.,  8., 11.,  9., 11., 10., 11., 11.,
         12.,  1., 12.,  1., 12.,  2., 12.,  3., 12.,  4., 12.,  5., 12.,  6.,
         12.,  7., 12.,  8., 12.,  9., 12., 10., 12.

In [46]:
# Target shape for the reshape below: the leading dims of x followed by
# m_ind's own (n_outputs, degree) shape
x.shape[:-1]+m_ind.shape

torch.Size([1, 152, 2])

In [45]:
# Regroup the gathered values so that the last axis holds the `degree`
# factors belonging to each monomial
torch.index_select(x,-1,m_ind.flatten()).reshape(x.shape[:-1]+m_ind.shape)

tensor([[[ 1.,  1.],
         [ 1.,  1.],
         [ 2.,  1.],
         [ 2.,  1.],
         [ 2.,  2.],
         [ 3.,  1.],
         [ 3.,  1.],
         [ 3.,  2.],
         [ 3.,  3.],
         [ 4.,  1.],
         [ 4.,  1.],
         [ 4.,  2.],
         [ 4.,  3.],
         [ 4.,  4.],
         [ 5.,  1.],
         [ 5.,  1.],
         [ 5.,  2.],
         [ 5.,  3.],
         [ 5.,  4.],
         [ 5.,  5.],
         [ 6.,  1.],
         [ 6.,  1.],
         [ 6.,  2.],
         [ 6.,  3.],
         [ 6.,  4.],
         [ 6.,  5.],
         [ 6.,  6.],
         [ 7.,  1.],
         [ 7.,  1.],
         [ 7.,  2.],
         [ 7.,  3.],
         [ 7.,  4.],
         [ 7.,  5.],
         [ 7.,  6.],
         [ 7.,  7.],
         [ 8.,  1.],
         [ 8.,  1.],
         [ 8.,  2.],
         [ 8.,  3.],
         [ 8.,  4.],
         [ 8.,  5.],
         [ 8.,  6.],
         [ 8.,  7.],
         [ 8.,  8.],
         [ 9.,  1.],
         [ 9.,  1.],
         [ 9.,  2.],
         [ 9.

In [48]:
# Multiply the factors along the last axis to get one value per monomial.
# This mirrors the computation performed in MonomialLayer.forward.
torch.prod(torch.index_select(x,-1,m_ind.flatten()).reshape(x.shape[:-1]+m_ind.shape),axis=-1)

tensor([[  1.,   1.,   2.,   2.,   4.,   3.,   3.,   6.,   9.,   4.,   4.,   8.,
          12.,  16.,   5.,   5.,  10.,  15.,  20.,  25.,   6.,   6.,  12.,  18.,
          24.,  30.,  36.,   7.,   7.,  14.,  21.,  28.,  35.,  42.,  49.,   8.,
           8.,  16.,  24.,  32.,  40.,  48.,  56.,  64.,   9.,   9.,  18.,  27.,
          36.,  45.,  54.,  63.,  72.,  81.,  10.,  10.,  20.,  30.,  40.,  50.,
          60.,  70.,  80.,  90., 100.,  11.,  11.,  22.,  33.,  44.,  55.,  66.,
          77.,  88.,  99., 110., 121.,  12.,  12.,  24.,  36.,  48.,  60.,  72.,
          84.,  96., 108., 120., 132., 144.,  13.,  13.,  26.,  39.,  52.,  65.,
          78.,  91., 104., 117., 130., 143., 156., 169.,  14.,  14.,  28.,  42.,
          56.,  70.,  84.,  98., 112., 126., 140., 154., 168., 182., 196.,  15.,
          15.,  30.,  45.,  60.,  75.,  90., 105., 120., 135., 150., 165., 180.,
         195., 210., 225.,  16.,  16.,  32.,  48.,  64.,  80.,  96., 112., 128.,
         144., 160., 176., 1