#### 1. Design a layer that takes an input and computes a tensor reduction, i.e., it returns $y_k = \sum_{i, j} W_{ijk} x_i x_j$

We do a functional reduction: we multiply the $n$ inputs pairwise to obtain $n^2$ nodes consisting of $x_i x_j$ (the amount of computation can be reduced to $n(n+1)/2$ products by noting that $x_i x_j = x_j x_i$). Also see Exercise 3.1.3.

In [1]:
import torch
from torch import Tensor, nn

# This is a prime candidate for a type of LazyQuadratic Mixin functionality
# as we don't really care how many inputs we have for this layer.
class QuadraticFeatures(nn.Module):
    """
    Expand an input vector into all of its pairwise products.

    Takes in n inputs and has n * (n + 1) / 2 outputs, where the outputs are
    x_i * x_j for all 1 <= i <= j <= n (the squares x_i * x_i are included).
    The products are ordered row by row:
    (1, 1), (1, 2), ..., (1, n), (2, 2), ..., (n, n).

    Leading dimensions are treated as batch dimensions, so an input of shape
    (..., n) produces an output of shape (..., n * (n + 1) / 2).
    """

    def __init__(self, num_in: int):
        """Store the expected input width ``num_in`` and the derived output width."""
        super().__init__()
        self.num_in = num_in
        self.num_interactions = self.num_in * (self.num_in + 1) // 2
        # NOTE: forward() never reads this parameter. It is kept so that
        # state dicts and parameter listings of existing users stay unchanged.
        self.weights = nn.Parameter(torch.randn(self.num_interactions))

    def forward(self, x: Tensor) -> Tensor:
        """
        Return the float32 tensor of products x_i * x_j with i <= j.

        Raises:
            ValueError: if the last dimension of ``x`` is not ``num_in``
                (or ``x`` is a 0-dimensional tensor).
        """
        # x.shape = (..., n) => forward(x).shape = (..., n * (n + 1) / 2)
        if x.dim() == 0 or x.shape[-1] != self.num_in:
            raise ValueError(f"Expected input shape {self.num_in}, but got {x.shape}")
        # Upper-triangular index pairs (diagonal included), ordered by row and
        # then by column; gathering with them replaces a Python-level loop.
        rows, cols = torch.triu_indices(self.num_in, self.num_in, device=x.device)
        return (x[..., rows] * x[..., cols]).float()

class TensorReducer(nn.Module):
    """
    Quadratic feature expansion followed by a learned linear read-out.

    The input (x_1, ..., x_n) is first expanded by ``QuadraticFeatures`` and
    the expanded vector is then mapped to ``num_out`` values by a lazily
    initialized linear layer.
    """

    def __init__(self, num_in: int, num_out: int):
        super().__init__()
        # num_in is only needed by the feature expansion, which validates the
        # input width. The linear layer is lazy and infers its own input width
        # on the first forward pass, so it is not told num_in at all.
        self.quadratic_features = QuadraticFeatures(num_in)
        self.linear = nn.LazyLinear(num_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (x_1, ..., x_n) -> quadratic features -> (y_1, ..., y_m),
        # with n = num_in and m = num_out.
        return self.linear(self.quadratic_features(x))

# Build a 4-input, 3-output layer and run one forward pass; the result is
# discarded, the call is only there so the lazy linear layer materializes
# its parameters before they are printed.
my_layer = TensorReducer(4, 3)
my_layer(torch.tensor([1.0, 2.0, 3.0, 4.0]))
print(f"Note that {(4 * (4 + 1) // 2)=}, which is where the 10 comes from.")
print()
# For every parameter: its name, its value, its shape, then a blank line.
for name, param in my_layer.named_parameters():
    print(name, param, param.shape, "", sep="\n")

Note that (4 * (4 + 1) // 2)=10, which is where the 10 comes from.

quadratic_features.weights
Parameter containing:
tensor([-0.1705, -0.0158, -0.5977, -0.9266,  0.1062,  2.0876, -0.5053,  0.0987,
        -1.4225, -0.9952], requires_grad=True)
torch.Size([10])

linear.weight
Parameter containing:
tensor([[ 0.1196,  0.1580, -0.0579, -0.1071,  0.2171,  0.0097, -0.2498, -0.1591,
          0.0797,  0.2451],
        [ 0.2625, -0.0648, -0.0801,  0.1312,  0.0295,  0.0800, -0.0326,  0.2585,
         -0.1507, -0.1237],
        [ 0.2545,  0.0711, -0.0905, -0.1273, -0.1730, -0.2179, -0.3147, -0.0760,
          0.2336, -0.0460]], requires_grad=True)
torch.Size([3, 10])

linear.bias
Parameter containing:
tensor([ 0.1557,  0.2839, -0.1932], requires_grad=True)
torch.Size([3])

