In [2]:
import torch
import torch.nn as nn
import math
import numpy

In [4]:
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer with optional attention-head pruning.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct an MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        # Attention weights from the most recent forward pass (None before any).
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    @staticmethod
    def _keeps_all_heads(survive_head_idx):
        """Return True when ``survive_head_idx`` requests no head pruning.

        ``None`` or a sequence whose first entry is the scalar ``-1`` means
        "keep every head"; any other value lists the heads to keep.
        """
        if survive_head_idx is None:
            return True
        first = survive_head_idx[0]
        # Compare by value: ``is not -1`` relied on CPython's small-int cache
        # (and is a SyntaxWarning on 3.8+). Non-scalar entries always prune.
        if isinstance(first, (list, tuple, torch.Tensor)):
            return False
        return first == -1

    def forward_attention(self, value, scores, mask, survive_head_idx=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2);
                positions where the mask equals 0/False are masked out. May be None.
            survive_head_idx (list, optional): Indices of heads whose attention
                weights are kept; every other head's weights are set to zero.
                ``None`` (the default) or ``[-1]`` keeps all heads.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
            torch.Tensor: Attention weights (#batch, n_head, time1, time2)
                after masking and head pruning, before dropout.

        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # torch.finfo covers every torch float dtype; numpy has no bfloat16.
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        if not self._keeps_all_heads(survive_head_idx):
            # Start from zeros so pruned heads contribute nothing (empty_like
            # would leave uninitialized memory), and write via indexed
            # assignment: ``t[:, idx].copy_(...)`` with a list index writes
            # into a temporary copy, so the result would be discarded.
            pruned = torch.zeros_like(attn)
            pruned[:, survive_head_idx] = attn[:, survive_head_idx]
            attn = pruned
        self.attn = attn

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x), self.attn  # (batch, time1, d_model)

    def forward(self, query, key, value, mask, survive_head_idx=None):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
            survive_head_idx (list, optional): Heads to keep; see
                ``forward_attention``. Pass it positionally when a forward
                pre-hook must see or rewrite it, since by default pre-hooks
                only receive positional arguments.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Attention weights (#batch, n_head, time1, time2).

        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask, survive_head_idx)

In [6]:
class Hook:
    """Forward pre-hook that forces which attention heads survive.

    Intended for a module whose forward takes
    ``(query, key, value, mask[, survive_head_idx])`` positionally (such as
    ``MultiHeadedAttention``). On every call the hook replaces the head-index
    argument with a fresh ``[survive_head_idx]`` list.

    Args:
        module (torch.nn.Module): Module to hook.
        layer_idx (int): Layer index; only used to build ``name``.
        survive_head_idx: Head index (or indices) to keep.
        module_type (str): Unused; kept for interface compatibility.
        attn_type (str): Unused; kept for interface compatibility.
        backward (bool): Must be False; backward hooks are not supported.

    Raises:
        NotImplementedError: If ``backward`` is true.
    """

    def __init__(self, module, layer_idx, survive_head_idx, module_type='encoder', attn_type='self_attn', backward=False):
        if backward:
            # hook_fn has the forward-pre-hook signature (module, input); a
            # backward hook is invoked as (module, grad_input, grad_output),
            # so registering it could only fail later, during backward().
            raise NotImplementedError(
                "Hook only supports forward pre-hooks; backward=True is not implemented"
            )

        self.layer_idx = layer_idx
        self.survive_head_idx = survive_head_idx
        self.name = str(self.layer_idx) + '_' + str(self.survive_head_idx)
        # Positional inputs received on the most recent call (None until then).
        self.input = None
        self.hook = module.register_forward_pre_hook(self.hook_fn)

    def hook_fn(self, module, input):
        """Rewrite the head-index argument of a forward call.

        Args:
            module (torch.nn.Module): The hooked module (unused).
            input (tuple): Positional arguments of the call: 4 items
                (head index omitted) or 5 items.

        Returns:
            tuple: ``(query, key, value, mask, [self.survive_head_idx])``.

        Raises:
            ValueError: If ``input`` does not hold 4 or 5 positional items.
        """
        if len(input) not in (4, 5):
            raise ValueError(
                f"expected 4 or 5 positional inputs, got {len(input)}"
            )
        self.input = input
        query, key, value, mask = input[:4]
        # Return a new list rather than writing into the caller's list, which
        # would otherwise be mutated in place; this also works when the caller
        # omitted the head-index argument entirely.
        return query, key, value, mask, [self.survive_head_idx]

    def close(self):
        """Remove the hook from the module."""
        self.hook.remove()

In [8]:
# 8 heads, 512 features, dropout 0.1; use the GPU when present, else the CPU,
# so the notebook also runs on machines without CUDA.
test_attn = MultiHeadedAttention(8, 512, 0.1).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

In [10]:
# Show the module structure (its submodules and their settings).
print(f"{test_attn}")

MultiHeadedAttention(
  (linear_q): Linear(in_features=512, out_features=512, bias=True)
  (linear_k): Linear(in_features=512, out_features=512, bias=True)
  (linear_v): Linear(in_features=512, out_features=512, bias=True)
  (linear_out): Linear(in_features=512, out_features=512, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)


In [12]:
# Random (1, 99, 512) query/key/value inputs and an all-True (1, 1, 99) mask.
# Placed on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
temp_q = torch.rand(1, 99, 512).to(device)
temp_k = torch.rand(1, 99, 512).to(device)
temp_v = torch.rand(1, 99, 512).to(device)
temp_mask = torch.ones(1, 1, 99, dtype=torch.bool, device=device)

In [14]:
# Attach a head-pruning pre-hook to test_attn that keeps only head 2.
hook_fn = Hook(survive_head_idx=2, layer_idx=0, module=test_attn)

In [16]:
# All arguments are positional on purpose: the forward pre-hook only sees
# positional inputs, so the head-index list must not be a keyword argument.
# The call returns an (output, attention weights) tuple.
out_after_delete = test_attn(
    temp_q,
    temp_k,
    temp_v,
    temp_mask,
    [-1],
)

In [19]:
# Display the 5th positional input (the head-index list) seen by the hook.
# NOTE(review): this needs Hook to expose an `input` attribute; confirm it does.
recorded = hook_fn.input
recorded[4]

[-1]