# Final Project - Transformer Implementation
**Authors**:
* Yuxuan Sun: <yuxuan_eric_sun@outlook.com>
* Sergey: <seriy.karp2@gmail.com>
* Haitao Gao: <haitaogao423@gmail.com>

**Code Repository**: <https://github.com/Erostrate9/needle>


## Introduction
To sum up what we learned during *10-714: Deep Learning Systems*, we implemented the Transformer architecture and its component modules in *needle*, our self-made deep learning library.

The overall goal of our *Final Project* is to implement a trainable Transformer architecture [1], which can be divided into the following ingredients: Multi-Head Attention, Self-Attention and Positional Encoding, and the Transformer Architecture itself (Position-wise Feed-Forward Networks, Residual Connection and Layer Normalization, Transformer Encoder Block & Encoder, Transformer Decoder Block & Decoder, and the Encoder-Decoder Seq2Seq model).

To simplify verifying numerical correctness, we follow the implementation in d2l.ai [2] and compare our results against its PyTorch implementation [3].

[1]: Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. "Attention is all you need." Advances in neural information processing systems 30, 2017.

[2]: Zhang, Aston, Zachary C. Lipton, Mu Li, and Alexander J. Smola. "Dive into deep learning." arXiv preprint arXiv:2106.11342, 2021.

[3]: Zhang, Aston, Zachary C. Lipton, Mu Li, and Alexander J. Smola. "Releases v0.17.5 · D2L-ai/D2L-en-pytorch." GitHub, 27-May-2022. [Online]. Available: <https://github.com/d2l-ai/d2l-en/releases/download/v0.17.5/d2l-en-pytorch.pdf>


## Multi-Head Attention
![Multi-head attention, where multiple heads are concatenated then linearly transformed [2, Fig. 11.5.1].](https://d2l.ai/_images/multi-head-attention.svg)
### Model
Since Prof. Zico Kolter introduced self-attention and the Transformer in Lectures 20 and 21, this report only briefly reviews the theory and focuses on the implementation.
According to [4], attention is usually implemented in practice as *multi-head attention*: we run the self-attention mechanism separately on different subsets of the columns of $K$, $Q$, and $V$, then concatenate the results. Formally, we partition these terms as
\begin{equation}
K = \begin{bmatrix} K_1 & K_2 & \cdots & K_{\mathrm{heads}} \end{bmatrix}
\end{equation}
(and similarly for $Q$ and $V$).

Then we form the self-attention output of each head
\begin{equation}
Y_i = \mathrm{softmax}\left(\frac{K_iQ_i^T}{\sqrt{d/\mathrm{heads}}}\right)V_i
\end{equation}
and then form the final output
\begin{equation}
Y = \begin{bmatrix} Y_1 & Y_2 & \cdots & Y_{\mathrm{heads}} \end{bmatrix} W_o.
\end{equation}
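To make the partitioning concrete, here is a minimal NumPy sketch of the three equations above for a single (unbatched) sequence, assuming $K, Q, V \in \mathbb{R}^{T \times d}$ with $d$ divisible by the number of heads. The helpers `softmax` and `multihead_attention` are illustrative names of ours, not part of *needle*. The sketch keeps the $K_iQ_i^T$ ordering from [4]; the batched `DotProductAttention` in the Implementation section uses the more common queries-times-transposed-keys ordering.

```python
import numpy as np


def softmax(Z):
    # Row-wise softmax over the last axis, shifted by the max for stability
    Z = Z - Z.max(axis=-1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=-1, keepdims=True)


def multihead_attention(K, Q, V, W_o, heads):
    """Multi-head attention for one sequence, following the equations above.

    K, Q, V: arrays of shape (T, d); W_o: output projection of shape (d, d).
    """
    d = K.shape[1]
    assert d % heads == 0, "d must be divisible by the number of heads"
    # K = [K_1 ... K_heads]: split column-wise into blocks of width d/heads
    Ks = np.split(K, heads, axis=1)
    Qs = np.split(Q, heads, axis=1)
    Vs = np.split(V, heads, axis=1)
    # Y_i = softmax(K_i Q_i^T / sqrt(d/heads)) V_i for every head i
    Ys = [softmax(Ki @ Qi.T / np.sqrt(d / heads)) @ Vi
          for Ki, Qi, Vi in zip(Ks, Qs, Vs)]
    # Y = [Y_1 ... Y_heads] W_o
    return np.concatenate(Ys, axis=1) @ W_o


rng = np.random.default_rng(0)
T, d, heads = 5, 8, 2
K, Q, V = (rng.standard_normal((T, d)) for _ in range(3))
W_o = rng.standard_normal((d, d))
Y = multihead_attention(K, Q, V, W_o, heads)
print(Y.shape)  # (5, 8)
```

With `heads=1` the function reduces to ordinary single-head self-attention followed by `W_o`, which is a handy sanity check.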

[4]: Kolter, Zico. "Public_notebooks/transformer_implementation.ipynb at main · dlsyscourse/public_notebooks." Deep Learning Systems 21 - Transformers + Attention Implementation, 15-Nov-2022. [Online]. Available: <https://github.com/dlsyscourse/public_notebooks/blob/main/transformer_implementation.ipynb>

### Implementation
First, for each head of the multi-head attention we use scaled dot-product attention, where a masked softmax turns the attention scores into a probability distribution that serves as the attention weights.
```python
import math

import numpy as np

# Assumes the course's needle package layout:
# Tensor in needle.autograd, ops in needle.ops, Module/Dropout in needle.nn
from needle import ops
from needle.autograd import Tensor
from needle.nn import Module, Dropout


class DotProductAttention(Module):
    """Scaled dot product attention."""

    def __init__(self, dropout):
        super().__init__()
        self.dropout = Dropout(dropout)

    def masked_softmax(self, X, valid_lens):
        """Softmax over the last axis of X, ignoring positions past valid_lens.

        valid_lens: NumPy array of shape (batch_size,) or
        (batch_size, no. of queries), or None for no masking.
        """
        def _sequence_mask(X: Tensor, valid_lens, value=0):
            # X: 2-D tensor of shape (n, maxlen); valid_lens: shape (n,)
            maxlen = X.shape[-1]
            # mask[r, j] is True iff position j lies inside the valid length of row r
            mask = np.arange(maxlen, dtype=np.float32)[None, :] < valid_lens[:, None]
            # Keep valid entries (multiply by 1) and overwrite the rest with `value`
            mask_mul = mask.astype(np.float32)
            mask_add = (~mask).astype(np.float32) * value
            mask_mul = Tensor(mask_mul, device=X.device, dtype=X.dtype, requires_grad=False)
            mask_add = Tensor(mask_add, device=X.device, dtype=X.dtype, requires_grad=False)
            return X * mask_mul + mask_add

        if valid_lens is None:
            return ops.softmax(X)
        else:
            shape = X.shape
            if len(valid_lens.shape) == 1:
                assert valid_lens.shape[0] == X.shape[0]
                # Repeat each sequence length once per query row
                valid_lens = valid_lens.repeat(X.shape[1])
            else:
                valid_lens = valid_lens.reshape(-1)
            # On the last axis, replace masked elements with a very large negative
            # value, whose exponentiation outputs 0
            X = _sequence_mask(X.reshape((math.prod(shape[:-1]), shape[-1])), valid_lens, value=-1e6)
            return ops.softmax(X.reshape(shape))

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Swap the last two dimensions of keys with keys.transpose((1, 2))
        scores = ops.bmm(queries, keys.transpose((1, 2))) / math.sqrt(d)
        self.attention_weights = self.masked_softmax(scores, valid_lens)
        return ops.bmm(self.dropout(self.attention_weights), values)
```
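The masking trick can be checked without *needle*. Below is a minimal NumPy mirror of `masked_softmax`, assuming plain `np.ndarray` inputs in place of needle `Tensor`s; `masked_softmax_np` and `softmax` are illustrative names of ours. It shows that keys past each valid length receive zero weight while the remaining weights are renormalized to sum to one.

```python
import numpy as np


def softmax(X):
    # Softmax over the last axis, shifted by the max for numerical stability
    X = X - X.max(axis=-1, keepdims=True)
    E = np.exp(X)
    return E / E.sum(axis=-1, keepdims=True)


def masked_softmax_np(X, valid_lens):
    """Illustrative NumPy mirror of DotProductAttention.masked_softmax."""
    if valid_lens is None:
        return softmax(X)
    shape = X.shape
    if valid_lens.ndim == 1:
        # One length per batch element: reuse it for every query row
        valid_lens = valid_lens.repeat(shape[1])
    else:
        # One length per (batch element, query) pair
        valid_lens = valid_lens.reshape(-1)
    X2 = X.reshape(-1, shape[-1])
    # mask[r, j] is True where key position j is inside row r's valid length
    mask = np.arange(shape[-1])[None, :] < valid_lens[:, None]
    # Keep valid scores, set masked ones to -1e6 so exp() underflows to 0
    X2 = X2 * mask + (~mask) * -1e6
    return softmax(X2.reshape(shape))


rng = np.random.default_rng(0)
scores = rng.standard_normal((2, 2, 4))  # (batch, queries, keys)
W = masked_softmax_np(scores, np.array([2, 3]))
print(W.round(3))
```

Because `exp(-1e6)` underflows to exactly zero in floating point, the padded keys contribute nothing to the normalizer, so the result on the valid keys equals a plain softmax over those keys alone.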
```python
class MultiHeadAttention(Module):
    """Multi-head attention."""

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, device=None, dtype="float32"):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = Linear(query_size, num_hiddens, bias=bias, device=device, dtype=dtype)
        self.W_k = Linear(key_size, num_hiddens, bias=bias, device=device, dtype=dtype)
        self.W_v = Linear(value_size, num_hiddens, bias=bias, device=device, dtype=dtype)
        self.W_o = Linear(num_hiddens, num_hiddens, bias=bias, device=device, dtype=dtype)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))
        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = valid_lens.repeat(repeats=self.num_heads, axis=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)

    def transpose_qkv(self, X):
        """Transposition for parallel computation of multiple attention heads."""
        # Shape of input X: (batch_size, no. of queries or key-value pairs,
        # num_hiddens). Shape of output X: (batch_size, no. of queries or
        # key-value pairs, num_heads, num_hiddens / num_heads)
        X = X.reshape((X.shape[0], X.shape[1], self.num_heads, -1))
        # Shape of output X: (batch_size, num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        X = X.permute((0, 2, 1, 3))
        # Shape of output: (batch_size * num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        X = X.reshape((-1, X.shape[2], X.shape[3]))
        return X

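    def transpose_output(self, X):
        """Reverse the operation of transpose_qkv."""
        # Shape of input X: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads). Shape of output X: (batch_size, num_heads,
        # no. of queries, num_hiddens / num_heads)
        X = X.reshape((-1, self.num_heads, X.shape[1], X.shape[2]))
        # Shape of output X: (batch_size, no. of queries, num_heads,
        # num_hiddens / num_heads)
        X = X.permute((0, 2, 1, 3))
        # Shape of output: (batch_size, no. of queries, num_hiddens)
        return X.reshape((X.shape[0], X.shape[1], -1))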
```
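The head bookkeeping in `transpose_qkv` and `transpose_output` can likewise be checked in NumPy, assuming needle's `reshape` and `permute` behave like NumPy's `reshape` and `transpose` with an axes tuple. This sketch (the free-function versions and their `num_heads` argument are ours) confirms that the two functions are inverses, that each row of the transposed tensor is one column block of a single batch element (the $K_i$ blocks from the Model section), and that `repeat(num_heads, axis=0)` lines `valid_lens` up with that batch-major ordering.

```python
import numpy as np


def transpose_qkv(X, num_heads):
    # (B, T, H) -> (B, T, heads, H/heads) -> (B, heads, T, H/heads)
    # -> (B * heads, T, H/heads)
    B, T, H = X.shape
    X = X.reshape(B, T, num_heads, H // num_heads)
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(B * num_heads, T, H // num_heads)


def transpose_output(X, num_heads):
    # (B * heads, T, H/heads) -> (B, heads, T, H/heads) -> (B, T, heads, H/heads)
    # -> (B, T, H)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)


B, T, H, heads = 2, 3, 8, 4
w = H // heads  # width of each head's column block
X = np.arange(B * T * H, dtype=np.float32).reshape(B, T, H)
Xh = transpose_qkv(X, heads)
print(Xh.shape)  # (8, 3, 2)

# Row b * heads + h holds column block h of batch element b
b, h = 1, 2
assert np.array_equal(Xh[b * heads + h], X[b][:, h * w:(h + 1) * w])
# transpose_output undoes transpose_qkv
assert np.array_equal(transpose_output(Xh, heads), X)

# Repeating valid_lens along axis 0 gives every head row its batch's length
valid_lens = np.array([3, 2]).repeat(heads, axis=0)
print(valid_lens)  # [3 3 3 3 2 2 2 2]
```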

Sanity checks comparing `torch.arange` with `np.arange`, and `mean` with `sum` divided by the axis length:

```python
import torch
import numpy as np
maxlen = 40
valid_lens = np.array([3,2])
a = torch.arange((maxlen), dtype=torch.float32)[None, :]
a_ = np.arange(maxlen, dtype=np.float32)[None, :]
print("a:", np.linalg.norm(a.numpy() - a_))
x = torch.normal(3,4,(3,4,5))
y = x.mean(dim=1)
y_ = x.sum(dim=1) / x.shape[1]
print("mean:", np.linalg.norm(y.numpy() - y_.numpy()))
# mask = (torch.arange((maxlen), dtype=torch.float32)[None, :].numpy() < valid_lens[:, None])
```

```
a: 0.0
mean: 0.0
```


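Checking a NumPy one-hot encoding against `torch.nn.functional.one_hot`:

```python
import numpy as np
import torch
import torch.nn.functional as F

def one_hot(n, i):
    # Row i of the n x n identity is the one-hot vector for class i;
    # indexing with an integer array one-hot encodes every element
    return np.eye(n)[i]

target = torch.empty((64, 10, 15, 4), dtype=torch.long).random_(201)

# Pass num_classes explicitly so the shape does not depend on target.max()
oh = F.one_hot(target, num_classes=201)
oh_ = one_hot(201, target.detach().numpy())
print(np.linalg.norm(oh.detach().numpy()-oh_))
```

```
0.0
```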


```python
import numpy as np

y_one_hot = np.random.randn(64, 10, 15, 201)

axes = ((0, len(y_one_hot.shape)-1, *tuple(range(len(y_one_hot.shape)))[1:-1])) if len(y_one_hot.shape)>1 else (0,)
print(axes)
```

```
(0, 3, 1, 2)
```
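The `axes` tuple computed above keeps axis 0, moves the last axis to position 1, and shifts the middle axes right by one, turning a `(batch, ..., classes)` array into `(batch, classes, ...)`. A short check, using a simplified but equivalent expression for the tuple (the name `moved` is ours), confirms that transposing with it matches NumPy's `np.moveaxis`:

```python
import numpy as np

y_one_hot = np.random.randn(64, 10, 15, 201)
n = y_one_hot.ndim
# Same tuple as above: (0, last axis, 1, ..., n - 2)
axes = (0, n - 1, *range(1, n - 1)) if n > 1 else (0,)
moved = np.transpose(y_one_hot, axes)
print(axes, moved.shape)  # (0, 3, 1, 2) (64, 201, 10, 15)
# Moving the last axis to position 1 gives the same result
assert np.array_equal(moved, np.moveaxis(y_one_hot, -1, 1))
```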