In [1]:
from typing import Tuple, Optional

import torch
import torch.nn.functional as F
from torch import nn
from torch import layer_norm

In [14]:
from turtle import forward


class RHNCell(nn.Module):
    """Single time step of a Recurrent Highway Network (RHN) layer.

    One call applies ``depth`` stacked highway micro-layers to the recurrent
    state.  For micro-layer ``d`` (0-based)::

        [h_pre | g_pre] = W_h[d](s) + (W_x(x) if d == 0 else 0)
        h = tanh(h_pre)            # candidate state
        g = sigmoid(g_pre)         # transform gate
        s = h * g + s * (1 - g)    # carry gate coupled as c = 1 - g

    The external input ``x`` enters only at the first micro-layer; every
    later micro-layer sees only the state produced by the previous one.
    """

    def __init__(self, input_dim: int, hidden_dim: int, depth: int, *args, **kwargs) -> None:
        """
        input_dim:  size of the last dimension of ``x``
        hidden_dim: size of the last dimension of the state ``s``
        depth:      number of highway micro-layers per time step (>= 1)

        Raises ValueError if ``depth`` is smaller than 1.
        """
        super().__init__(*args, **kwargs)

        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.depth = depth

        # Both projections emit 2 * hidden_dim features: the first half is the
        # candidate pre-activation, the second half the gate pre-activation.
        # W_x has no bias because W_h[0] already contributes one.
        self.W_x = nn.Linear(input_dim, 2 * hidden_dim, bias=False)
        self.W_h = nn.ModuleList([
            nn.Linear(hidden_dim, 2 * hidden_dim)
            for _ in range(depth)
        ])  # one recurrent transformation per micro-layer

    def forward(self, x: torch.Tensor, s: torch.Tensor):
        """
        x: batch_size, input_dim
        s: batch_size, hidden_dim

        Returns the state after all ``depth`` micro-layers, with the same
        shape as ``s``.  The carry gate is c = 1 - g as described in paper.
        """
        for d, recurrent in enumerate(self.W_h):
            h_and_g = recurrent(s)
            if d == 0:
                # The input is injected at the first micro-layer only.
                h_and_g = h_and_g + self.W_x(x)

            h_pre, g_pre = h_and_g.chunk(2, dim=-1)
            h = torch.tanh(h_pre)
            g = torch.sigmoid(g_pre)

            # Feed the highway output forward so the next micro-layer builds
            # on it instead of recomputing from the initial state.
            s = h * g + s * (1 - g)

        return s

        

In [None]:
class RHN(nn.Module):
    """Stack of ``n_layers`` RHNCell layers unrolled over a sequence.

    At every time step the first cell consumes the input features and each
    following cell consumes the state just produced by the cell below it.
    """

    def __init__(self, input_dim: int, hidden_dim: int, n_layers: int, depth: int, return_sequences: bool = True, batch_first = True) -> None:
        """
        input_dim:        feature size of the input sequence
        hidden_dim:       state size of every layer
        n_layers:         number of stacked cells (>= 1)
        depth:            recurrence depth passed to every cell
        return_sequences: return the output of every step (True) or only the
                          last step (False)
        batch_first:      whether ``x`` is (batch, steps, features) rather
                          than (steps, batch, features)

        Raises ValueError if ``n_layers`` is smaller than 1.
        """
        super().__init__()

        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.return_sequences = return_sequences
        self.batch_first = batch_first

        # Only the first cell sees the raw features; the others map
        # hidden_dim -> hidden_dim.
        self.cells = nn.ModuleList([RHNCell(input_dim, hidden_dim, depth)] +
                                    [RHNCell(hidden_dim, hidden_dim, depth) for _ in range(self.n_layers - 1)])

    def forward(self, x: torch.Tensor, s: Optional[torch.Tensor] = None):
        """
        x : (batch_size, n_steps, feature_dim) if batch_first
        else (n_steps, batch_size, feature_dim)
        s : optional initial state, either (n_layers, batch_size, hidden_dim)
        or (n_layers, hidden_dim), the latter being shared by every batch
        element.  Defaults to zeros with the dtype and device of ``x``.

        Returns ``(out, s)``:
        out : list with one (batch_size, hidden_dim) tensor per time step
        taken from the top layer, or only the last of them when
        ``return_sequences`` is False
        s   : (n_layers, batch_size, hidden_dim) final state of every layer

        Raises ValueError if the first dimension of ``s`` is not n_layers.
        """
        if self.batch_first:
            x = x.transpose(0, 1)

        n_steps, batch_size = x.shape[:2]

        if s is None:
            # new_zeros keeps the state on the same device/dtype as the input.
            s = x.new_zeros(self.n_layers, batch_size, self.hidden_dim)
        elif s.dim() == 2:
            s = s.unsqueeze(1).expand(-1, batch_size, -1)

        if s.shape[0] != self.n_layers:
            raise ValueError(
                f"expected an initial state for {self.n_layers} layers, got {s.shape[0]}"
            )

        # Keep one tensor per layer instead of writing into a shared buffer:
        # in-place writes would invalidate values autograd needs for backward
        # and would make every stored output alias the final state.
        states = list(s.unbind(0))
        out = []

        for t in range(n_steps):
            inputs = x[t]
            for l, cell in enumerate(self.cells):
                states[l] = cell(inputs, states[l])
                inputs = states[l]

            out.append(inputs)  # state of the top layer at step t

        if not self.return_sequences: out = out[-1]

        return out, torch.stack(states)

In [11]:
# Toy 2x4 tensor standing in for a concatenated [h | g] pre-activation.
hg = torch.Tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
])

# First of two equal chunks along the last dimension.
torch.chunk(hg, 2, dim=-1)[0].size()

torch.Size([2, 2])

In [13]:
# Plain slicing of the first two columns, for comparison with chunk(2, dim=-1)[0].
hg[:, 0:2].size()

torch.Size([2, 2])