# 11.5. Multi-Head Attention

## Author: Isaac Reyes

In practice, given the same set of queries, keys, and values, we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range) within a sequence. Thus, it may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.

To this end, instead of performing a single attention pooling, queries, keys, and values can be transformed with $h$ independently learned linear projections. Then these $h$ projected queries, keys, and values are fed into attention pooling in parallel. In the end, the $h$ attention-pooling outputs are concatenated and transformed with another learned linear projection to produce the final output. This design is called *multi-head attention*, where each of the $h$ attention-pooling outputs is a *head*.

```perl
# Libraries
use strict;
use warnings;
use Data::Dump qw(dump);
use AI::MXNet qw(mx);   # provides the `mx` shortcut used below (d2l may already load it)
use d2l;
IPerl->load_plugin('Chart::Plotly');
```

### 11.5.1. Model

Based on this design, each head may attend to different parts of the input. More sophisticated functions than the simple weighted average can be expressed.
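One way to write this design down formally, using notation introduced here for illustration (it follows the standard multi-head attention formulation): given a query $\mathbf{q} \in \mathbb{R}^{d_q}$, a key $\mathbf{k} \in \mathbb{R}^{d_k}$, and a value $\mathbf{v} \in \mathbb{R}^{d_v}$, each head $\mathbf{h}_i$ ($i = 1, \ldots, h$) applies an attention pooling $f$, such as scaled dot-product attention, to its own learned projections:

$$\mathbf{h}_i = f\left(\mathbf{W}_i^{(q)}\mathbf{q}, \mathbf{W}_i^{(k)}\mathbf{k}, \mathbf{W}_i^{(v)}\mathbf{v}\right) \in \mathbb{R}^{p_v},$$

where $\mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$, $\mathbf{W}_i^{(k)} \in \mathbb{R}^{p_k \times d_k}$, and $\mathbf{W}_i^{(v)} \in \mathbb{R}^{p_v \times d_v}$ are learnable. The output is another learned linear projection $\mathbf{W}_o \in \mathbb{R}^{p_o \times h p_v}$ of the concatenated heads:

$$\mathbf{W}_o \begin{bmatrix} \mathbf{h}_1 \\ \vdots \\ \mathbf{h}_h \end{bmatrix} \in \mathbb{R}^{p_o}.$$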

### 11.5.2. Implementation

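In our implementation, we choose scaled dot-product attention for each head of the multi-head attention. To avoid significant growth of computational cost and parametrization cost, we set $p_q = p_k = p_v = p_o / h$, where $p_q$, $p_k$, and $p_v$ are the output sizes of each head's query, key, and value projections and $p_o$ is the output size of the final projection. Note that the $h$ heads can be computed in parallel if we set the number of outputs of the linear transformations for the query, key, and value to $p_q h = p_k h = p_v h = p_o$. In the following implementation, $p_o$ is specified via the argument `num_hiddens`.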
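The parallel-computation trick can be illustrated with a small pure-Perl sketch that avoids MXNet and works on plain nested arrays. `multi_head_attention`, `matvec`, and `softmax` are hypothetical helpers written only for this illustration, the weights are made up, and the sketch handles a single query without batching, masking, or dropout. Each of `W_q`, `W_k`, and `W_v` has $p_o$ rows, so one matrix-vector product per input yields all heads at once; head $i$ then reads the $i$-th consecutive slice of $p_o / h$ entries.

```perl
use strict;
use warnings;
use List::Util qw(sum max);

# Matrix-vector product: $W is an array ref of rows, $x a vector (array ref).
sub matvec {
    my ($W, $x) = @_;
    return [ map { my $row = $_; sum(map { $row->[$_] * $x->[$_] } 0 .. $#$x) } @$W ];
}

# Numerically stable softmax over a list of scores.
sub softmax {
    my @scores = @_;
    my $m = max(@scores);
    my @e = map { exp($_ - $m) } @scores;
    my $z = sum(@e);
    return map { $_ / $z } @e;
}

# Hypothetical helper: multi-head attention for ONE query. One product with a
# p_o-row matrix computes every head's projection; head $i uses entries
# [$i * p, ($i + 1) * p - 1] with p = p_o / num_heads. This equals giving each
# head its own p-row projection matrix (the matching rows of W_q, W_k, W_v).
sub multi_head_attention {
    my (%args) = @_;
    my ($q, $keys, $values, $h) = @args{qw(query keys values num_heads)};
    my $p_o = scalar @{ $args{W_q} };
    die "num_hiddens must be divisible by num_heads\n" if $p_o % $h;
    my $p = $p_o / $h;

    my $qp = matvec($args{W_q}, $q);
    my @kp = map { matvec($args{W_k}, $_) } @$keys;
    my @vp = map { matvec($args{W_v}, $_) } @$values;

    my @concat;
    for my $i (0 .. $h - 1) {
        my @idx = ($i * $p) .. (($i + 1) * $p - 1);
        # Scaled dot-product scores of head $i against every key.
        my @w = softmax(map {
            my $k = $_;
            sum(map { $qp->[$_] * $k->[$_] } @idx) / sqrt($p)
        } @kp);
        # Weighted sum of head $i's value slices, appended to the concatenation.
        for my $j (@idx) {
            push @concat, sum(map { $w[$_] * $vp[$_][$j] } 0 .. $#vp);
        }
    }
    # Output projection W_o over the concatenated heads.
    return matvec($args{W_o}, \@concat);
}

# Toy example with made-up weights: inputs of size 3, p_o = 4, 2 heads.
my %weights = (
    W_q => [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]],
    W_k => [[0, 1, 0], [1, 0, 0], [1, 1, 0], [0, 0, 1]],
    W_v => [[1, 2, 0], [0, 1, 1], [2, 0, 1], [1, 0, 0]],
    W_o => [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [1, 1, 1, 1]],
);
my $out = multi_head_attention(
    query     => [1, 0, 1],
    keys      => [[1, 2, 3], [0, 1, 0]],
    values    => [[1, 1, 1], [1, 1, 1]],   # identical values
    num_heads => 2,
    %weights,
);
print join(' ', @$out), "\n";
```

Because the two value vectors are identical, every head returns its own value slice regardless of its attention weights, so the result should equal $\mathbf{W}_o \mathbf{W}_v \mathbf{v} = (3, 2, 3, 9)$ up to floating-point rounding.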

```perl
package MultiHeadAttention {
    use base ("d2l::Module");

    #@Save
    # Multi-head attention built from num_heads scaled dot-product attention heads.
    sub new {
        my ($class, %args) = (
            shift,
            d2l->get_arguments(
                num_hiddens => undef,   # p_o: output size of every projection
                num_heads   => undef,   # h: number of attention heads
                dropout     => undef,   # dropout rate inside the attention pooling
                use_bias    => 0,       # whether the linear projections use a bias
                \@_
            )
        );
        my $self = $class->SUPER::new(%args);
        $self->{num_heads} = $args{num_heads};
        $self->{attention} = d2l::DotProductAttention->new($args{dropout});
        # W_q, W_k, W_v project to num_hiddens features, which transpose_qkv later
        # splits into num_heads chunks; W_o projects the concatenated heads.
        for my $name (qw(W_q W_k W_v W_o)) {
            $self->{$name} = mx->gluon->nn->Dense(
                $args{num_hiddens}, use_bias => $args{use_bias}, flatten => 0
            );
        }
        $self->register_child($self->{$_}) for qw(attention W_q W_k W_v W_o);
        return bless($self, $class);
    }

    sub forward {
        my ($self, $queries, $keys, $values, $valid_lens) = @_;
        # Shape of queries, keys, values:
        #   (batch_size, no. of queries or key-value pairs, num_hiddens)
        # After transpose_qkv:
        #   (batch_size * num_heads, no. of queries or key-value pairs, num_hiddens / num_heads)
        $queries = $self->transpose_qkv($self->{W_q}->forward($queries));
        $keys    = $self->transpose_qkv($self->{W_k}->forward($keys));
        $values  = $self->transpose_qkv($self->{W_v}->forward($values));

        if (defined $valid_lens) {
            # Copy each valid length num_heads times along axis 0 so the
            # batch_size * num_heads rows line up with the transpose_qkv layout.
            $valid_lens = $valid_lens->repeat($self->{num_heads}, axis => 0);
        }

        # Shape of output: (batch_size * num_heads, no. of queries, num_hiddens / num_heads)
        my $output = $self->{attention}->forward($queries, $keys, $values, $valid_lens);
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        my $output_concat = $self->transpose_output($output);
        return $self->{W_o}->forward($output_concat);
    }

    1;
}
```


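To allow parallel computation of multiple heads, the `MultiHeadAttention` class above uses two transposition methods, defined below. The `transpose_output` method reverses the operation of `transpose_qkv`.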

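```perl
# Transposition for parallel computation of multiple attention heads.
# Shape of input X: (batch_size, no. of queries or key-value pairs, num_hiddens)
# Shape of output:  (batch_size * num_heads, no. of queries or key-value pairs,
#                    num_hiddens / num_heads)
my $transpose_qkv = sub {
    my ($self, $X) = @_;
    # (batch_size, n, num_heads, num_hiddens / num_heads)
    $X = $X->reshape([$X->shape->[0], $X->shape->[1], $self->{num_heads}, -1]);
    # (batch_size, num_heads, n, num_hiddens / num_heads)
    $X = $X->transpose([0, 2, 1, 3]);
    return $X->reshape([-1, $X->shape->[2], $X->shape->[3]]);
};
d2l->add_to_class('MultiHeadAttention', 'transpose_qkv', $transpose_qkv);

# Reverse the operation of transpose_qkv.
my $transpose_output = sub {
    my ($self, $X) = @_;
    # (batch_size, num_heads, n, num_hiddens / num_heads)
    $X = $X->reshape([-1, $self->{num_heads}, $X->shape->[1], $X->shape->[2]]);
    # (batch_size, n, num_heads, num_hiddens / num_heads)
    $X = $X->transpose([0, 2, 1, 3]);
    # (batch_size, n, num_hiddens)
    return $X->reshape([$X->shape->[0], $X->shape->[1], -1]);
};
d2l->add_to_class('MultiHeadAttention', 'transpose_output', $transpose_output);
```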
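The index bookkeeping behind these two methods can be checked without MXNet. The sketch below is a hypothetical pure-Perl reference (`transpose_qkv_ref` and `transpose_output_ref` are names made up here) that reproduces the same reshape-transpose-reshape mapping on nested arrays: row `b * num_heads + h` of the result holds head `h` of batch element `b`, and the inverse mapping restores the original layout.

```perl
use strict;
use warnings;

# (batch, n, num_heads * d) -> (batch * num_heads, n, d).
# Row $bi * num_heads + $h of the result holds head $h of batch element $bi.
sub transpose_qkv_ref {
    my ($X, $num_heads) = @_;
    my @out;
    for my $bi (0 .. $#$X) {
        for my $n (0 .. $#{ $X->[$bi] }) {
            my $row = $X->[$bi][$n];
            die "feature size must be divisible by num_heads\n" if @$row % $num_heads;
            my $d = @$row / $num_heads;    # per-head feature size
            for my $h (0 .. $num_heads - 1) {
                $out[$bi * $num_heads + $h][$n] = [ @$row[$h * $d .. ($h + 1) * $d - 1] ];
            }
        }
    }
    return \@out;
}

# Inverse mapping: (batch * num_heads, n, d) -> (batch, n, num_heads * d).
sub transpose_output_ref {
    my ($Y, $num_heads) = @_;
    my @out;
    for my $r (0 .. $#$Y) {
        my $bi = int($r / $num_heads);    # batch element
        my $h  = $r % $num_heads;         # head index
        for my $n (0 .. $#{ $Y->[$r] }) {
            my $chunk = $Y->[$r][$n];
            my $d = @$chunk;
            @{ $out[$bi][$n] }[$h * $d .. ($h + 1) * $d - 1] = @$chunk;
        }
    }
    return \@out;
}

# Toy input: batch 2, 3 positions, 4 features (num_heads = 2, d = 2).
# Entry 100*b + 10*n + j encodes its own position.
my $X;
for my $bi (0 .. 1) {
    for my $n (0 .. 2) {
        $X->[$bi][$n] = [ map { $_ + 100 * $bi + 10 * $n } 0 .. 3 ];
    }
}
my $Y = transpose_qkv_ref($X, 2);     # 4 rows, each 3 positions x 2 features
my $Z = transpose_output_ref($Y, 2);  # back to 2 x 3 x 4
print "@{ $Y->[1][0] }\n";            # head 1 of batch 0, position 0
```

Row 1 of `$Y` is head 1 of the first batch element, so the script prints `2 3`, the last two features of `$X->[0][0]`; `$Z` matches `$X` entry for entry.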


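To test the implementation, we use a toy example in which keys and values are the same tensor. The output of multi-head attention should have shape (`batch_size`, `num_queries`, `num_hiddens`).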

```perl
my ($num_hiddens, $num_heads) = (100, 5);
my $attention = MultiHeadAttention->new($num_hiddens, $num_heads, 0.5);   # dropout = 0.5
$attention->initialize();
my ($batch_size, $num_queries, $num_kvpairs) = (2, 4, 6);
my $valid_lens = mx->nd->array([3, 2]);
# Toy inputs: keys and values are the same tensor Y.
my $X = mx->nd->ones([$batch_size, $num_queries, $num_hiddens]);
my $Y = mx->nd->ones([$batch_size, $num_kvpairs, $num_hiddens]);
d2l->check_shape($attention->forward($X, $Y, $Y, $valid_lens),
                 [$batch_size, $num_queries, $num_hiddens]);
```
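The test passes one valid length per sequence (`[3, 2]`), while the attention heads operate on `batch_size * num_heads` rows. The pure-Perl sketch below, using hypothetical helpers `repeat_axis0` and `head_row` and assuming MXNet's `repeat` follows NumPy-style semantics (each element copied consecutively), checks why repeating along axis 0, rather than tiling the whole list, matches the row layout that `transpose_qkv` produces.

```perl
use strict;
use warnings;

# Hypothetical helper mimicking NumPy-style repeat along axis 0: each element
# is copied num_heads times in a row, e.g. [3, 2] -> [3,3,3,3,3,2,2,2,2,2].
sub repeat_axis0 {
    my ($lens, $num_heads) = @_;
    return [ map { ($_) x $num_heads } @$lens ];
}

# Row index, in the (batch_size * num_heads)-row layout of transpose_qkv,
# that holds head $h of batch element $bi.
sub head_row {
    my ($bi, $h, $num_heads) = @_;
    return $bi * $num_heads + $h;
}

my @valid_lens = (3, 2);    # same values as in the test above
my $num_heads  = 5;
my $repeated   = repeat_axis0(\@valid_lens, $num_heads);
print "@$repeated\n";

# Every head of batch element $bi must see that element's own valid length.
for my $bi (0 .. $#valid_lens) {
    for my $h (0 .. $num_heads - 1) {
        die "head $h of batch $bi is misaligned\n"
            unless $repeated->[ head_row($bi, $h, $num_heads) ] == $valid_lens[$bi];
    }
}
```

This prints `3 3 3 3 3 2 2 2 2 2`. Tiling instead, as in `(3, 2) x 5`, would give `3 2 3 2 ...`, so head 1 of the first sequence would wrongly receive valid length 2.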


### 11.5.3. Summary
Multi-head attention combines knowledge of the same attention pooling via different representation subspaces of queries, keys, and values. To compute multiple heads of multi-head attention in parallel, proper tensor manipulation is needed.

### 11.5.4. Exercises

1. Visualize attention weights of multiple heads in this experiment.
2. Suppose that we have a trained model based on multi-head attention and we want to prune less important attention heads to increase the prediction speed. How can we design experiments to measure the importance of an attention head?