In [101]:
import torch
from torch import Tensor, nn
from torch.nn import Module, Parameter
from torch.nn import functional as F

class ExponentialDecay(Module):
    """Map an input vector to a non-negative mixture of exponential curves.

    A small MLP turns every input row into ``n_components`` non-negative
    mixing coefficients. The output is the coefficient-weighted sum of the
    basis curves ``exp(-decay_rates[k] * t)`` sampled at the integer steps
    ``t = 0, 1, ..., out_dim - 1``.

    Args:
        in_dim: Size of the feature dimension of the input.
        out_dim: Number of steps at which each curve is sampled.
        n_units: Width of the hidden layer of the coefficient MLP.
        n_components: Number of exponential basis curves.
    """

    def __init__(self, in_dim, out_dim, n_units, n_components) -> None:
        super().__init__()

        # One learnable rate per basis curve, initialised uniformly in [0, 1).
        # NOTE(review): nothing keeps the rates positive during training; a
        # rate that drifts below zero produces a growing curve instead.
        self.decay_rates = nn.Parameter(torch.rand(n_components, 1))  # (n_components, 1)

        # Registered as a buffer (not a plain attribute) so that it follows the
        # module through .to(device) / .cuda(). persistent=False keeps it out
        # of state_dict, so existing checkpoints still load unchanged.
        self.register_buffer(
            "roll_out",
            torch.arange(out_dim).unsqueeze(0),
            persistent=False,
        )  # (1, out_dim)

        self.layers = nn.Sequential(
            nn.Linear(in_dim, n_units),
            nn.GELU(),
            nn.Linear(n_units, n_components),
            nn.ReLU(),  # mixing coefficients must be non-negative
        )  # output: (bs, n_components)

    def forward(self, x: Tensor) -> Tensor:
        """Return the mixture of exponential curves for each input row.

        Args:
            x: Input of shape ``(bs, in_dim)``.

        Returns:
            Tensor of shape ``(bs, out_dim)``.
        """
        c = self.layers(x)  # (bs, n_components)
        # Integer steps are promoted to the parameter's float dtype here.
        base = torch.exp(-self.decay_rates * self.roll_out)  # (n_components, out_dim)
        y = c @ base  # (bs, out_dim)
        return y



In [102]:
# Seed the global RNG so the random draws below are reproducible on
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

bs = 2            # batch size
in_dim = 20       # input feature size
out_dim = 5       # number of output steps
n_units = 20      # hidden-layer width
n_components = 3  # number of exponential components
x = torch.randn(bs, in_dim)  # (bs, in_dim)

In [103]:
# Build the model from the sizes defined in the configuration cell.
model = ExponentialDecay(
    in_dim=in_dim,
    out_dim=out_dim,
    n_units=n_units,
    n_components=n_components,
)

In [104]:
x = model(x)

In [105]:
x

tensor([[0.5630, 0.3370, 0.2018, 0.1208, 0.0723],
        [0.1488, 0.0635, 0.0271, 0.0116, 0.0049]], grad_fn=<MmBackward0>)

In [91]:
# NOTE(review): `r0` and `base` are not defined in any cell visible here, and
# this cell's execution count (91) predates the cells above (101-105). It looks
# like a leftover from an earlier session; define both names explicitly or
# remove the cell so the notebook survives Restart & Run All.
r0 @ base

tensor([[ 3.0000,  0.6386,  0.1720,  0.0547,  0.0190],
        [12.0000,  5.1482,  3.6329,  3.2115,  3.0749]])