In [15]:
import torch
from torch import Tensor, nn
from torch.nn import Module, Parameter
from torch.nn import functional as F

class ExponentialDecay(Module):
    def __init__(self, in_dim, out_dim, n_units, n_components) -> None:
        super().__init__()
        
        self.decay_rates = nn.Parameter(torch.rand(n_components, 1)) # (n_components, 1)
        self.roll_out = nn.Parameter(torch.arange(out_dim).unsqueeze(0), requires_grad=False) # (1, out_dim)
        self.layers = nn.Sequential(
            nn.Linear(in_dim, n_units),
            nn.GELU(),
            nn.Linear(n_units, n_components),
            nn.ReLU(), # must be postive
        ) # output: (bs, n_components)

    def forward(self, x: Tensor) -> Tensor:
        c = self.layers(x)
        base = torch.exp(-self.decay_rates * self.roll_out) # (n_components, out_dim)
        y = c @ base # (bs, out_dim)
        return y



In [16]:
bs = 2
in_dim = 20
out_dim = 5
n_units = 20
n_components = 3
x = torch.randn(bs, in_dim)

In [17]:
model = ExponentialDecay(in_dim, out_dim, n_units, n_components)

In [18]:
list(model.parameters())

[Parameter containing:
 tensor([[0.7662],
         [0.2885],
         [0.3969]], requires_grad=True),
 Parameter containing:
 tensor([[0, 1, 2, 3, 4]]),
 Parameter containing:
 tensor([[ 1.8829e-01, -8.7000e-02, -1.9602e-01, -1.6427e-02, -4.6554e-02,
          -1.1720e-01,  1.7484e-01, -2.1821e-01,  1.9922e-01,  4.5252e-02,
          -2.6194e-02, -1.3156e-01, -9.6790e-02, -4.2549e-02, -2.0596e-01,
           1.4484e-01, -1.8550e-01,  4.9998e-02,  2.1038e-01, -3.8049e-02],
         [ 3.9169e-02, -7.2638e-02,  2.1693e-01, -1.9897e-01,  1.1179e-01,
           4.2406e-02,  1.5575e-01,  1.7575e-01, -1.4472e-01, -1.4603e-01,
           5.7868e-02,  7.8792e-02,  3.8275e-02, -1.1662e-01,  6.7328e-02,
          -1.5888e-01,  5.4873e-02, -1.5217e-01, -8.5088e-02, -1.7898e-01],
         [ 2.1439e-01,  1.7532e-01, -2.1333e-01,  1.7507e-01,  9.3965e-02,
           3.8418e-02,  1.2424e-01, -4.3539e-02,  1.1170e-01,  1.1077e-01,
          -7.9096e-02,  1.0794e-01, -2.2937e-02, -5.4485e-02,  1.0842e-0

In [19]:
[p for p in model.parameters() if p.requires_grad]

[Parameter containing:
 tensor([[0.7662],
         [0.2885],
         [0.3969]], requires_grad=True),
 Parameter containing:
 tensor([[ 1.8829e-01, -8.7000e-02, -1.9602e-01, -1.6427e-02, -4.6554e-02,
          -1.1720e-01,  1.7484e-01, -2.1821e-01,  1.9922e-01,  4.5252e-02,
          -2.6194e-02, -1.3156e-01, -9.6790e-02, -4.2549e-02, -2.0596e-01,
           1.4484e-01, -1.8550e-01,  4.9998e-02,  2.1038e-01, -3.8049e-02],
         [ 3.9169e-02, -7.2638e-02,  2.1693e-01, -1.9897e-01,  1.1179e-01,
           4.2406e-02,  1.5575e-01,  1.7575e-01, -1.4472e-01, -1.4603e-01,
           5.7868e-02,  7.8792e-02,  3.8275e-02, -1.1662e-01,  6.7328e-02,
          -1.5888e-01,  5.4873e-02, -1.5217e-01, -8.5088e-02, -1.7898e-01],
         [ 2.1439e-01,  1.7532e-01, -2.1333e-01,  1.7507e-01,  9.3965e-02,
           3.8418e-02,  1.2424e-01, -4.3539e-02,  1.1170e-01,  1.1077e-01,
          -7.9096e-02,  1.0794e-01, -2.2937e-02, -5.4485e-02,  1.0842e-01,
          -8.1278e-03, -2.0964e-01, -2.1471e-01,

tensor([[0.5630, 0.3370, 0.2018, 0.1208, 0.0723],
        [0.1488, 0.0635, 0.0271, 0.0116, 0.0049]], grad_fn=<MmBackward0>)

In [91]:
r0 @ base

tensor([[ 3.0000,  0.6386,  0.1720,  0.0547,  0.0190],
        [12.0000,  5.1482,  3.6329,  3.2115,  3.0749]])