In [12]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [23]:
import torch 
from torch import nn 
 
class AttentionUpdateGateGRUCell(nn.Module):
    """GRU cell whose update gate is rescaled by an attention score.

    The three gates share two stacked weight matrices. ``forward`` splits
    the stacked projections along dim 1 in the order reset, update, new.

    Args:
        input_size: size of the last dimension of ``x``.
        hidden_size: size of the hidden state.
        bias: when False, ``bias_ih`` and ``bias_hh`` are registered as
            ``None`` and the projections are computed without a bias.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        stacked = 3 * hidden_size
        # Input-to-hidden weights; rows are consumed by forward() as
        # (reset | update | new).
        self.weight_ih = nn.Parameter(torch.Tensor(stacked, input_size))
        # Hidden-to-hidden weights, same row layout.
        self.weight_hh = nn.Parameter(torch.Tensor(stacked, hidden_size))

        if not bias:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
        else:
            # One bias vector per projection, same (reset | update | new)
            # layout as the weights.
            self.bias_ih = nn.Parameter(torch.Tensor(stacked))
            self.bias_hh = nn.Parameter(torch.Tensor(stacked))

        # The tensors above are uninitialised until this call.
        self.reset_parameters()

    def reset_parameters(self):
        """Fill every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        bound = 1.0 / self.hidden_size ** 0.5
        for param in self.parameters():
            nn.init.uniform_(param, a=-bound, b=bound)

    def forward(self, x, hx, att_score):
        """Compute the next hidden state.

        Args:
            x: input tensor; projected with ``weight_ih`` and chunked along
                dim 1, so it is expected to be ``(batch, input_size)``.
            hx: previous hidden state, expected ``(batch, hidden_size)``.
            att_score: attention weights; reshaped with ``view(-1, 1)`` so
                that one scalar scales each row of the update gate.

        Returns:
            The new hidden state, ``(1 - u) * hx + u * n`` where ``u`` is
            the attention-scaled update gate and ``n`` the candidate state.
        """
        in_r, in_u, in_n = F.linear(x, self.weight_ih, self.bias_ih).chunk(3, 1)
        hid_r, hid_u, hid_n = F.linear(hx, self.weight_hh, self.bias_hh).chunk(3, 1)

        reset = torch.sigmoid(in_r + hid_r)
        # The reset gate only modulates the hidden-side projection.
        candidate = torch.tanh(in_n + reset * hid_n)

        # Attention scales the update gate: a score of 0 keeps hx unchanged.
        update = att_score.view(-1, 1) * torch.sigmoid(in_u + hid_u)
        return (1 - update) * hx + update * candidate

In [None]:
class AttentionUpdateGateGRUCell(nn.Module):
    """GRU cell whose update gate is scaled by an attention score.

    With ``W_x*`` / ``W_h*`` the row blocks of ``weight_xrzh`` /
    ``weight_hrzh`` (order: reset, update, candidate), one step computes::

        r  = sigmoid(x W_xr^T + h W_hr^T + b_r)
        z  = sigmoid(x W_xz^T + h W_hz^T + b_z)
        n  = tanh(x W_xh^T + r * (h W_hh^T) + b_h)
        z' = att_score * z
        h' = (1 - z') * h + z' * n

    NOTE(review): the previous ``forward`` never used ``bias_h``; it is added
    to the candidate pre-activation here by analogy with the other two
    gates -- confirm this is the intended formulation.

    Args:
        input_size: size of the last dimension of ``x``.
        hidden_size: size of the hidden state.
        bias: when False, ``bias_r``, ``bias_z`` and ``bias_h`` are
            registered as ``None`` and no bias is added.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # All tensors below are placeholders: reset_parameters() overwrites
        # every parameter before the constructor returns.
        # (W_xr | W_xz | W_xh), stacked along dim 0.
        self.weight_xrzh = nn.Parameter(
            torch.empty(3 * hidden_size, input_size, dtype=torch.float32))
        # (W_hr | W_hz | W_hh), stacked along dim 0.
        self.weight_hrzh = nn.Parameter(
            torch.empty(3 * hidden_size, hidden_size, dtype=torch.float32))
        if bias:
            # One bias vector per gate.
            self.bias_r = nn.Parameter(
                torch.empty(hidden_size, dtype=torch.float32))
            self.bias_z = nn.Parameter(
                torch.empty(hidden_size, dtype=torch.float32))
            self.bias_h = nn.Parameter(
                torch.empty(hidden_size, dtype=torch.float32))
        else:
            # Register the names forward() actually reads, so that the
            # bias-free configuration does not fail on attribute lookup.
            for name in ('bias_r', 'bias_z', 'bias_h'):
                self.register_parameter(name, None)
        self.reset_parameters()

    def reset_parameters(self):
        """Fill every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).

        This is the same range torch.nn.GRUCell documents for its weights
        and biases. Biases are drawn from it too, not left at zero.
        """
        stdv = 1.0 / self.hidden_size ** 0.5
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    @staticmethod
    def _add_bias(pre_activation, bias):
        """Return ``pre_activation + bias``, or it unchanged when bias is None."""
        if bias is None:
            return pre_activation
        return pre_activation + bias

    def forward(self, x, hx, att_score):
        """Compute the next hidden state.

        Args:
            x: input tensor; projected with ``weight_xrzh`` and chunked along
                dim 1, so it is expected to be ``(batch, input_size)``.
            hx: previous hidden state, expected ``(batch, hidden_size)``.
            att_score: attention weights; reshaped with ``view(-1, 1)`` so
                that one scalar scales each row of the update gate.

        Returns:
            The new hidden state, with the same shape as ``hx``.
        """
        X_r, X_z, X_h = F.linear(x, self.weight_xrzh).chunk(3, 1)
        H_r, H_z, H_h = F.linear(hx, self.weight_hrzh).chunk(3, 1)

        reset_gate = torch.sigmoid(self._add_bias(X_r + H_r, self.bias_r))
        update_gate = torch.sigmoid(self._add_bias(X_z + H_z, self.bias_z))
        # The reset gate modulates only the hidden-side projection.
        candidate = torch.tanh(
            self._add_bias(X_h + reset_gate * H_h, self.bias_h))

        # Attention scales the update gate: a score of 0 keeps hx unchanged.
        update_gate = att_score.view(-1, 1) * update_gate
        hy = (1 - update_gate) * hx + update_gate * candidate

        return hy