In [18]:
import torch
from torch import nn

class ResidualGate(nn.Module):
    """Gated residual connection that merges a residual stream ``x`` with a branch output ``y``.

    Every gate is computed from ``x`` only; ``y`` does not influence the gate values.

    ``gate_application`` selects how gates are applied (``act`` is ``gate_activation``):

    * ``'none'``         -- ``z = x + y`` (plain residual, no parameters)
    * ``'reset'``        -- ``z = r(x) * x + y``
    * ``'update'``       -- ``z = x + u(x) * y``
    * ``'reset-update'`` -- ``z = r(x) * x + u(x) * y``
    * ``'combined'``     -- ``z = g(x) * x + (1 - g(x)) * y`` (single gate, stored under the
      ``reset_gate_*`` attribute names)

    ``gate_compute`` selects how each gate's pre-activation is produced:

    * ``'linear-bias'`` -- ``W x + b`` (``nn.Linear`` with bias)
    * ``'linear'``      -- ``W x`` (``nn.Linear`` without bias)
    * ``'bias'``        -- a learned per-feature vector ``b`` (initialised to zero), broadcast to
      ``x``'s shape

    Args:
        d_model: size of the last dimension of ``x`` (input/output width of the gate layers).
        gate_application: one of ``'reset-update'``, ``'reset'``, ``'update'``, ``'combined'``, ``'none'``.
        gate_compute: one of ``'linear-bias'``, ``'linear'``, ``'bias'``.
        gate_activation: one of ``'sigmoid'``, ``'tanh'``, ``'none'`` (identity).

    Raises:
        ValueError: if any option is not one of the values listed above.
    """

    GATE_APPLICATIONS = ('reset-update', 'reset', 'update', 'combined', 'none')
    GATE_COMPUTES = ('linear-bias', 'linear', 'bias')
    GATE_ACTIVATIONS = {'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh, 'none': nn.Identity}

    def __init__(self,
            d_model: int,
            gate_application: str = 'reset-update', # 'reset-update' or 'reset' or 'update' or 'combined' or 'none'
            gate_compute: str = 'linear-bias', # 'linear-bias' or 'linear' or 'bias'
            gate_activation: str = 'sigmoid', # 'sigmoid' or 'tanh' or 'none'
            ):
        super().__init__()

        # Validate up front so a bad option fails at construction, not at the first forward call.
        if gate_application not in self.GATE_APPLICATIONS:
            raise ValueError(f'Unknown gate_application: {gate_application}')
        if gate_compute not in self.GATE_COMPUTES:
            raise ValueError(f'Unknown gate_compute: {gate_compute}')
        if gate_activation not in self.GATE_ACTIVATIONS:
            raise ValueError(f'Unknown gate_activation: {gate_activation}')

        self.d_model = d_model
        self.gate_application = gate_application
        self.gate_compute = gate_compute
        self.gate_activation = gate_activation
        # 'none' maps to nn.Identity so the activation is always callable.
        self.gate_activation_fn = self.GATE_ACTIVATIONS[gate_activation]()

        needs_update_gate = gate_application in ('reset-update', 'update')
        # 'combined' uses a single gate g, stored under the reset-gate names.
        needs_reset_gate = gate_application in ('reset-update', 'reset', 'combined')

        # The update gate is registered before the reset gate to keep parameter order stable.
        if needs_update_gate:
            self._make_gate('update', d_model)
        if needs_reset_gate:
            self._make_gate('reset', d_model)

    # TODO: bias initialization (non-zero)

    def _make_gate(self, name, d_model):
        """Register gate ``name`` as ``{name}_gate_bias`` (bias mode) or ``{name}_gate_linear``."""
        if self.gate_compute == 'bias':
            setattr(self, f'{name}_gate_bias', nn.Parameter(torch.zeros(d_model)))
        else:
            use_bias = self.gate_compute == 'linear-bias'
            setattr(self, f'{name}_gate_linear', nn.Linear(d_model, d_model, bias=use_bias))

    def _compute_gate(self, name, x):
        """Evaluate gate ``name`` ('update' or 'reset') on ``x`` and apply the activation."""
        if self.gate_compute in ('linear', 'linear-bias'):
            pre_activation = getattr(self, f'{name}_gate_linear')(x)
        elif self.gate_compute == 'bias':
            # Broadcast the learned per-feature bias to x's full shape.
            pre_activation = torch.zeros_like(x) + getattr(self, f'{name}_gate_bias')
        else:
            raise ValueError(f'Unknown gate_compute: {self.gate_compute}')
        return self.gate_activation_fn(pre_activation)

    def _compute_update_gate(self, x):
        """Update gate u(x); scales the branch output ``y``."""
        return self._compute_gate('update', x)

    def _compute_reset_gate(self, x):
        """Reset gate r(x) (or the single gate g(x) in 'combined' mode); scales ``x``."""
        return self._compute_gate('reset', x)

    def forward(self, x, y):
        """Combine residual ``x`` and branch output ``y`` according to ``gate_application``.

        ``x`` and ``y`` are expected to have the same shape, with last dimension ``d_model``.
        Returns the gated sum ``z``.
        """
        mode = self.gate_application
        if mode == 'none':
            return x + y
        if mode == 'update':
            # The update gate scales y, consistent with 'reset-update'.
            return x + self._compute_update_gate(x) * y
        if mode == 'reset':
            return self._compute_reset_gate(x) * x + y
        if mode == 'reset-update':
            update_gate = self._compute_update_gate(x)
            reset_gate = self._compute_reset_gate(x)
            return reset_gate * x + update_gate * y
        if mode == 'combined':
            gate = self._compute_reset_gate(x)
            return gate * x + (1 - gate) * y
        raise ValueError(f'Unknown gate_application: {self.gate_application}')

# TODO: all the gating mechanisms above depend on x but not on y.
# In an LSTM, for example, gates depend on both x and y, i.e., gate = sigmoid(Wx + Uy + b).

In [19]:
import torchinfo

In [20]:
# Smoke test: build every (gate_application, gate_compute) combination, check the output
# shape, and print its torchinfo summary.
torch.manual_seed(0)  # reproducible inputs and weight initialisation

input_size = (1, 10, 64)  # (batch, tokens, d_model); the last dim is the gate width
d_model = input_size[-1]
x, y = torch.randn(input_size), torch.randn(input_size)

# Possible values for gate_application and gate_compute
gate_application_options = ['reset-update', 'reset', 'update', 'combined', 'none']
gate_compute_options = ['linear-bias', 'linear', 'bias']

# Iterate over each combination
for gate_application in gate_application_options:
    for gate_compute in gate_compute_options:
        print(f"Summary for gate_application={gate_application}, gate_compute={gate_compute}:")
        model = ResidualGate(d_model, gate_application=gate_application, gate_compute=gate_compute)
        assert model(x, y).shape == x.shape, "gated residual must preserve the input shape"
        # verbose=0: torchinfo otherwise prints on its own outside Jupyter, which would
        # duplicate the explicit print below.
        print(torchinfo.summary(model, input_data=[x, y], verbose=0))
        print("\n")

Summary for gate_application=reset-update, gate_compute=linear-bias:
update_gate.shape: torch.Size([1, 10, 64])
reset_gate.shape: torch.Size([1, 10, 64])
z.shape: torch.Size([1, 10, 64])
Layer (type:depth-idx)                   Output Shape              Param #
ResidualGate                             [1, 10, 64]               --
├─Linear: 1-1                            [1, 10, 64]               4,160
├─Sigmoid: 1-2                           [1, 10, 64]               --
├─Linear: 1-3                            [1, 10, 64]               4,160
├─Sigmoid: 1-4                           [1, 10, 64]               --
Total params: 8,320
Trainable params: 8,320
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.01
Input size (MB): 0.01
Forward/backward pass size (MB): 0.01
Params size (MB): 0.03
Estimated Total Size (MB): 0.05


Summary for gate_application=reset-update, gate_compute=linear:
update_gate.shape: torch.Size([1, 10, 64])
reset_gate.shape: torch.Size([1, 10, 64])
z.shape:

In [16]:
# Interactive IPython help lookup (development aid only); remove before sharing the notebook.
torchinfo.summary?

[0;31mSignature:[0m
[0mtorchinfo[0m[0;34m.[0m[0msummary[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mmodel[0m[0;34m:[0m [0;34m'nn.Module'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0minput_size[0m[0;34m:[0m [0;34m'INPUT_SIZE_TYPE | None'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0minput_data[0m[0;34m:[0m [0;34m'INPUT_DATA_TYPE | None'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbatch_dim[0m[0;34m:[0m [0;34m'int | None'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcache_forward_pass[0m[0;34m:[0m [0;34m'bool | None'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcol_names[0m[0;34m:[0m [0;34m'Iterable[str] | None'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcol_width[0m[0;34m:[0m [0;34m'int'[0m [0;34m=[0m [0;36m25[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdepth[0m[0;34m:[0m [0;34m'int'[0m [