<a href="https://colab.research.google.com/github/DavoodSZ1993/Dive_into_Deep_Learning/blob/main/10_2_GRU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install d2l==1.0.0-alpha1.post0 --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.0/93.0 KB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.0/121.0 KB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m83.6/83.6 KB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m51.0 MB/s[0m eta [36m0:00:00[0m
[?25h

## 10.2 Gated Recurrent Units (GRU)

### 10.2.4 Implementation from Scratch

In [2]:
import torch
from torch import nn
from d2l import torch as d2l

#### Initializing Model Parameters

In [3]:
class GRUScratch(d2l.Module):
  """Learnable parameters of a gated recurrent unit built from scratch.

  Three parameter triples are created, one each for the update gate, the
  reset gate and the candidate hidden state. Every triple consists of:
    * an input-to-hidden weight of shape (num_inputs, num_hiddens),
    * a hidden-to-hidden weight of shape (num_hiddens, num_hiddens),
    * a bias of shape (num_hiddens,), initialised to zero.

  Args:
    num_inputs: Size of the feature dimension of each input step.
    num_hiddens: Size of the hidden state.
    sigma: Scale applied to the standard-normal weight initialisation.
  """
  def __init__(self, num_inputs, num_hiddens, sigma=0.01):
    super().__init__()
    # d2l helper; presumably exposes the constructor arguments as attributes
    # (forward reads self.num_hiddens) -- confirm against d2l. It is kept
    # ahead of the helper definitions below so that no extra local names
    # exist when it runs.
    self.save_hyperparameters()

    def gaussian(rows, cols):
      # Weight matrix sampled from a standard normal, scaled by sigma.
      return nn.Parameter(torch.randn(rows, cols) * sigma)

    def gate_params():
      # Input weight is drawn before the hidden weight, so the sequence of
      # random draws is W_x, then W_h, for each gate in turn.
      w_input = gaussian(num_inputs, num_hiddens)
      w_hidden = gaussian(num_hiddens, num_hiddens)
      bias = nn.Parameter(torch.zeros(num_hiddens))
      return w_input, w_hidden, bias

    # Update gate
    self.W_xz, self.W_hz, self.b_z = gate_params()
    # Reset gate
    self.W_xr, self.W_hr, self.b_r = gate_params()
    # Candidate hidden state
    self.W_xh, self.W_hh, self.b_h = gate_params()

#### Defining the Model

In [4]:
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
  """Run the GRU over a sequence, one time step at a time.

  Args:
    inputs: Tensor that is iterated over its first axis, one slice per time
      step. Axis 1 is taken as the batch size when the initial state is
      built, and each slice is multiplied by the
      (num_inputs, num_hiddens) input weights.
    H: Optional initial hidden state of shape (batch_size, num_hiddens).
      When omitted, a zero state is created on the same device as `inputs`.

  Returns:
    A tuple (outputs, H) where `outputs` is a list with the hidden state
    produced at every time step and `H` is the hidden state after the last
    step (the initial state if `inputs` has no time steps).
  """
  if H is None:
    # Initial state with shape: (batch_size, num_hiddens)
    H = torch.zeros((inputs.shape[1], self.num_hiddens), device=inputs.device)

  outputs = []
  for X in inputs:
    # Update gate: how much of the previous state is carried over.
    Z = torch.sigmoid(torch.matmul(X, self.W_xz) +
                      torch.matmul(H, self.W_hz) + self.b_z)
    # Reset gate: how much of the previous state feeds the candidate.
    R = torch.sigmoid(torch.matmul(X, self.W_xr) +
                      torch.matmul(H, self.W_hr) + self.b_r)
    # Candidate hidden state. The GRU definition uses tanh here, giving
    # values in (-1, 1); a sigmoid would restrict the candidate (and hence
    # the hidden state) to positive values.
    H_tilde = torch.tanh(torch.matmul(X, self.W_xh) +
                         torch.matmul(R * H, self.W_hh) + self.b_h)
    # Element-wise convex combination of the old state and the candidate.
    H = Z * H + (1 - Z) * H_tilde
    outputs.append(H)
  return outputs, H

#### Training