In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GRUCell(nn.Module):
    """Single time-step gated recurrent unit built from raw parameters.

    The cell computes, for input ``x_t`` and previous state ``h_prev``::

        z_t     = sigmoid(x_t W_z^T + h_prev U_z^T + b_z)          (update gate)
        r_t     = sigmoid(x_t W_r^T + h_prev U_r^T + b_r)          (reset gate)
        h_tilde = tanh(x_t W_h^T + (r_t * h_prev) U_h^T + b_h)     (candidate)
        h_t     = (1 - z_t) * h_prev + z_t * h_tilde

    Note the convention: ``z_t`` weights the *candidate* state, so ``z_t = 0``
    carries ``h_prev`` through unchanged.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        def small_normal(n_cols):
            # Weight of shape (hidden_size, n_cols), N(0, 1) scaled by 0.01.
            return nn.Parameter(torch.randn(hidden_size, n_cols) * 0.01)

        def zero_bias():
            return nn.Parameter(torch.zeros(hidden_size))

        # Creation order is kept fixed: it determines both the random values
        # drawn under a given seed and the parameter registration order.

        # Input-to-hidden weights, each (hidden_size, input_size).
        self.W_z = small_normal(input_size)
        self.W_r = small_normal(input_size)
        self.W_h = small_normal(input_size)

        # Hidden-to-hidden weights, each (hidden_size, hidden_size).
        self.U_z = small_normal(hidden_size)
        self.U_r = small_normal(hidden_size)
        self.U_h = small_normal(hidden_size)

        # Biases, each (hidden_size,), initialised to zero.
        self.b_z = zero_bias()
        self.b_r = zero_bias()
        self.b_h = zero_bias()

    def forward(self, x_t, h_prev):
        """Advance the hidden state by one time step.

        Args:
            x_t: Input at the current step, shape ``(batch, input_size)``.
            h_prev: Previous hidden state, shape ``(batch, hidden_size)``.

        Returns:
            The new hidden state, shape ``(batch, hidden_size)``.
        """
        # Update gate: how much of the candidate replaces the old state.
        update_pre = x_t @ self.W_z.T + h_prev @ self.U_z.T + self.b_z
        z_t = torch.sigmoid(update_pre)

        # Reset gate: how much of the old state feeds the candidate.
        reset_pre = x_t @ self.W_r.T + h_prev @ self.U_r.T + self.b_r
        r_t = torch.sigmoid(reset_pre)

        # Candidate state computed from the input and the reset-gated history.
        gated_history = r_t * h_prev
        candidate_pre = x_t @ self.W_h.T + gated_history @ self.U_h.T + self.b_h
        h_tilde = torch.tanh(candidate_pre)

        # Convex blend of the old state and the candidate.
        carried = (1 - z_t) * h_prev
        return carried + z_t * h_tilde

class GRU(nn.Module):
    """Single-layer, unidirectional GRU that unrolls one ``GRUCell`` over time.

    Inputs are batch-first: ``(batch, seq_len, input_size)``.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = GRUCell(input_size, hidden_size)

    def forward(self, x, h0=None):
        """Run the cell over every time step of ``x``.

        Args:
            x: Input sequence, shape ``(batch, seq_len, input_size)``.
            h0: Optional initial hidden state, shape ``(batch, hidden_size)``.
                When ``None``, a zero state with the dtype and device of
                ``x`` is used.

        Returns:
            A tuple ``(outputs, h_t)`` where ``outputs`` has shape
            ``(batch, seq_len, hidden_size)`` and holds the hidden state after
            each step, and ``h_t`` is the final hidden state of shape
            ``(batch, hidden_size)``. For a zero-length sequence ``outputs``
            is empty along the time axis and ``h_t`` is the initial state.
        """
        batch_size, seq_len, _ = x.shape

        if h0 is None:
            # new_zeros inherits dtype *and* device from x, so a model cast
            # with .double()/.half() does not hit a float32 default state
            # and fail on a dtype mismatch in the hidden-to-hidden matmul.
            h_t = x.new_zeros(batch_size, self.hidden_size)
        else:
            h_t = h0

        if seq_len == 0:
            # torch.stack/torch.cat reject an empty list; return an empty
            # time axis and the untouched initial state instead of crashing.
            return h_t.new_zeros(batch_size, 0, self.hidden_size), h_t

        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]           # (batch, input_size)
            h_t = self.cell(x_t, h_t)  # (batch, hidden_size)
            outputs.append(h_t)

        # Stack along a new time axis -> (batch, seq_len, hidden_size).
        return torch.stack(outputs, dim=1), h_t