In [13]:
import torch
import torch.nn as nn

class TensorDimReductionLayer(nn.Module):
    """Map each input vector to ``output_dim`` quadratic features.

    For an input vector ``x`` of length ``input_dim`` the layer computes

        y[k] = sum_{i, j} x[i] * x[j] * weights[i, j, k]

    i.e. one learned quadratic form per output feature. Because
    ``x[i] * x[j]`` is symmetric in ``i`` and ``j``, only the symmetric part
    of ``weights`` (over its first two axes) influences the output.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # NOTE(review): standard-normal init with no fan-in scaling, so output
        # magnitudes grow with input_dim; consider rescaling for large inputs.
        self.weights = nn.Parameter(torch.randn(input_dim, input_dim, output_dim))

    def forward(self, x):
        """Apply the quadratic map along the last dimension of ``x``.

        Args:
            x: Tensor of shape ``(..., input_dim)``. Any number of leading
                dimensions is accepted: none (a single vector), one (a batch),
                or more. Zero-sized batches are also accepted.

        Returns:
            Tensor of shape ``(..., output_dim)``.

        Raises:
            RuntimeError: If ``x`` is 0-dimensional or its last dimension is
                not ``input_dim``.
        """
        if x.dim() == 0 or x.size(-1) != self.input_dim:
            raise RuntimeError(
                f"expected input with last dimension {self.input_dim}, "
                f"got shape {tuple(x.shape)}"
            )
        # Contract over i first: partial[..., j, k] = sum_i x[i] * W[i, j, k].
        # Doing it in this order never builds the (..., input_dim**2) outer
        # product, and it also works for empty batches.
        w_flat = self.weights.reshape(self.input_dim, self.input_dim * self.output_dim)
        partial = torch.matmul(x, w_flat)
        partial = partial.reshape(*x.shape[:-1], self.input_dim, self.output_dim)
        # Then contract over j: y[..., k] = sum_j x[j] * partial[..., j, k].
        return torch.matmul(x.unsqueeze(-2), partial).squeeze(-2)

# Usage demo. No RNG seed is set, so the randomly initialized weights, the
# random input batch and therefore the printed values change on every run.
# The output has one row per input vector and output_dim columns.
input_dim, output_dim = 4, 3
layer = TensorDimReductionLayer(input_dim=input_dim, output_dim=output_dim)
x = torch.randn(5, input_dim)
y = layer(x)
print(y)

tensor([[-1.2104,  4.7954, -0.5102],
        [-0.4499, -0.8049,  2.7554],
        [-4.9088, 11.2100,  4.1863],
        [-0.3886,  0.7023, -0.6577],
        [-1.9788,  0.9281,  1.0598]], grad_fn=<MmBackward0>)
