In [1]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
%matplotlib inline

## Define the model




In [1]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, MultiheadAttention
from torch import einsum

def max_neg_value(tensor):
    """Return the most negative finite value representable in ``tensor``'s dtype.

    Args:
        tensor: a floating-point tensor; only its ``dtype`` is inspected.

    Returns:
        A Python float equal to ``-torch.finfo(tensor.dtype).max``.
    """
    dtype_limits = torch.finfo(tensor.dtype)
    return -dtype_limits.max

def exists(val):
    """Tell whether ``val`` holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    return not (val is None)

def linear_attn(q, k, v, kv_mask=None):
    """Compute linear attention over all key/value positions.

    The queries are soft-maxed over their last (feature) axis and the keys over
    their second-to-last (sequence) axis.  A per-head context ``k^T v`` is
    built first and then applied to the queries, so no ``n x n`` score matrix
    is ever formed.  Every query attends to every key position: no causal
    masking is applied here.

    Args:
        q: query tensor, indexed as ``(b, h, n, d)`` by the einsum expressions.
        k: key tensor, indexed as ``(b, h, n, d)``.
        v: value tensor, indexed as ``(b, h, n, e)``.
        kv_mask: optional mask indexed as ``(b, n)``; positions where it is
            ``False`` are excluded from the keys/values.  It is inverted with
            ``~``, so it is expected to be a boolean tensor.

    Returns:
        The attended values, reshaped to ``q.shape``.
    """
    dim = q.shape[-1]
    if exists(kv_mask):
        mask_value = max_neg_value(q)
        # Broadcast (b, n) -> (b, 1, n, 1) so it lines up with (b, h, n, d).
        mask = kv_mask[:, None, :, None]
        # Out-of-place fill: the in-place ``masked_fill_`` used previously
        # silently overwrote the caller's key/value tensors.
        k = k.masked_fill(~mask, mask_value)
        v = v.masked_fill(~mask, 0.)
        del mask
    q = q.softmax(dim=-1)
    k = k.softmax(dim=-2)
    q = q * dim ** -0.5

    # context[b, h] = k[b, h]^T @ v[b, h]  -> (d, e) summary of all positions.
    context = einsum('bhnd,bhne->bhde', k, v)
    attn = einsum('bhnd,bhde->bhne', q, context)
    return attn.reshape(*q.shape)

class LinearAttentionLayer(nn.Module):
    """Multi-head linear attention with learned query/key/value projections.

    Inputs are split into ``nhead`` heads along the feature axis, passed
    through ``linear_attn`` and merged back.  Axis 0 of the inputs is treated
    as the batch axis (``forward`` reshapes with ``query.size(0)`` first), i.e.
    the layer expects ``(batch, seq_len, d_model)`` tensors.

    Args:
        d_model: feature size of the inputs and outputs.
        nhead: number of attention heads; must divide ``d_model``.
        dropout: dropout probability applied to the attention output.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead, dropout=0.1):
        super(LinearAttentionLayer, self).__init__()
        if d_model % nhead != 0:
            # Without this check the head reshape in ``forward`` can succeed
            # with a silently wrong sequence length.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by nhead ({nhead})")
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.nhead = nhead
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, kv_mask=None):
        """Apply linear attention.

        Args:
            query, key, value: tensors whose last axis has size ``d_model`` and
                whose first axis is the batch axis.
            kv_mask: optional mask forwarded unchanged to ``linear_attn``.

        Returns:
            Tensor of shape ``(query.size(0), -1, query.size(-1))``.
        """
        q = self.q_linear(query)
        k = self.k_linear(key)
        v = self.v_linear(value)

        # (batch, seq, d_model) -> (batch, nhead, seq, head_dim)
        q = q.view(query.size(0), -1, self.nhead, query.size(-1)//self.nhead).transpose(1, 2)
        k = k.view(key.size(0), -1, self.nhead, key.size(-1)//self.nhead).transpose(1, 2)
        v = v.view(value.size(0), -1, self.nhead, value.size(-1)//self.nhead).transpose(1, 2)

        # ``linear_attn`` already returns the attended values (it multiplies by
        # ``v`` internally), so its result is the layer output.  The previous
        # extra ``torch.matmul(..., v)`` multiplied two (seq, head_dim)
        # matrices and failed whenever seq != head_dim.
        attn_output = linear_attn(q, k, v, kv_mask=kv_mask)
        attn_output = self.dropout(attn_output)

        # (batch, nhead, seq, head_dim) -> (batch, seq, d_model)
        output = attn_output.transpose(1, 2).contiguous().view(query.size(0), -1, query.size(-1))

        return output
class LinearAttentionMultiheadAttention(MultiheadAttention):
    """``MultiheadAttention`` subclass that also owns a ``LinearAttentionLayer``.

    All constructor arguments are forwarded to ``MultiheadAttention``; the
    linear-attention layer is stored on ``self.attention``.

    NOTE(review): only ``__init__`` is overridden here, so ``forward`` is the
    inherited one.  Nothing in this class calls ``self.attention``; unless the
    parent's ``forward`` reads that attribute, this module still computes
    standard attention.  Confirm before interpreting any
    "linear attention" results.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False,
                 add_zero_attn=False, kdim=None, vdim=None):
        super().__init__(
            embed_dim,
            num_heads,
            dropout=dropout,
            bias=bias,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            kdim=kdim,
            vdim=vdim,
        )

        # Intended replacement for the self-attention computation.
        self.attention = LinearAttentionLayer(embed_dim, num_heads, dropout=dropout)

class LinearAttentionTransformerEncoderLayer(TransformerEncoderLayer):
    """``TransformerEncoderLayer`` whose ``self_attn`` is swapped after construction.

    The parent constructor builds the full layer first; ``self_attn`` is then
    overwritten with a ``LinearAttentionMultiheadAttention`` built from the
    same ``d_model``, ``nhead`` and ``dropout``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super().__init__(
            d_model,
            nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )

        # Replace the self_attn module created by the parent constructor.
        replacement = LinearAttentionMultiheadAttention(d_model, nhead, dropout=dropout)
        self.self_attn = replacement


  from .autonotebook import tqdm as notebook_tqdm


In [48]:
import math
import os
from tempfile import TemporaryDirectory
from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

class TransformerModel(nn.Module):
    """Token-level language model: embedding -> positional encoding ->
    transformer encoder -> linear projection onto the vocabulary.

    Args:
        ntoken: vocabulary size (embedding rows and output logits).
        d_model: embedding / model feature size.
        nhead: number of attention heads per encoder layer.
        d_hid: feed-forward width inside each encoder layer.
        nlayers: number of stacked encoder layers.
        linear_attention: when true, build the stack from
            ``LinearAttentionTransformerEncoderLayer`` instead of
            ``TransformerEncoderLayer``.
        dropout: dropout probability used by the positional encoding and the
            encoder layers.
    """

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, linear_attention: bool = False,dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        if linear_attention:
            encoder_layers = LinearAttentionTransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        else:
            encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        """Re-initialise the embedding and output projection uniformly in
        ``[-0.1, 0.1]`` and zero the output bias."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        """
        Arguments:
            src: Tensor, shape ``[seq_len, batch_size]``
            src_mask: Tensor, shape ``[seq_len, seq_len]``; when omitted a
                square causal mask is generated.

        Returns:
            output Tensor of shape ``[seq_len, batch_size, ntoken]``
        """
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        if src_mask is None:
            # Generate a square causal mask for the sequence. The masked
            # positions are filled with float('-inf'); unmasked positions are
            # filled with float(0.0).  The mask is placed on the input's own
            # device rather than on a notebook-global ``device``, so the model
            # does not depend on a variable defined in another cell.
            src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(src.device)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

In [14]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Even feature indices carry ``sin(position * div_term)`` and odd indices
    carry ``cos(position * div_term)``.  The table is precomputed for
    ``max_len`` positions and stored as the non-trainable buffer ``pe`` of
    shape ``[max_len, 1, d_model]``.

    Args:
        d_model: feature size of the inputs (even or odd).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions precomputed.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd ``d_model`` there is one fewer odd-indexed column than
        # even-indexed ones, so trim the angles to fit; for even ``d_model``
        # the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            Tensor of the same shape as ``x``.
        """
        # ``pe`` has a singleton batch axis, so it broadcasts over the batch.
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

## Load and batch data




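This notebook uses ``torchtext`` to load the language-modeling corpus. The code below trains on Penn Treebank (``PennTreebank``); ``WikiText2`` is also imported but is not used by the code shown.
To access torchtext datasets, please install torchdata following the instructions at https://github.com/pytorch/data.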

In [4]:
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from transformers import GPT2Tokenizer, GPT2Model
from torchtext.datasets import PennTreebank
import pickle

def save_dataset(dataset, filename):
    """Pickle ``dataset`` into the file at ``filename``.

    The file is opened in binary write mode, so an existing file is
    overwritten.
    """
    with open(filename, mode='wb') as sink:
        pickle.dump(dataset, sink)
def load_dataset(filename):
    """Read back and return the object pickled in the file at ``filename``.

    SECURITY: ``pickle.load`` can execute arbitrary code while
    deserialising, so only open files that were produced by a trusted source
    (e.g. ``save_dataset`` in this notebook).
    """
    with open(filename, mode='rb') as source:
        return pickle.load(source)

# # Load the pretrained GPT-2 tokenizer
# model_name = "gpt2"
# tokenizer_bpe = GPT2Tokenizer.from_pretrained(model_name)
# train_iter, val_iter, test_iter = PennTreebank()

# def text_generator(raw_text_iter):
#     for item in raw_text_iter:
#         yield tokenizer_bpe.tokenize(item)
# vocab = build_vocab_from_iterator(text_generator(train_iter), specials=['<unk>'])
# vocab.set_default_index(vocab['<unk>'])
# print(len(vocab))

# Load the word-level tokenizer and build the vocabulary from the training split.
train_iter = PennTreebank(split='train')
tokenizer = get_tokenizer('basic_english')
# '<unk>' is registered as a special token and then made the default index,
# so looking up a token that is not in the vocabulary yields the '<unk>' id.
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])


def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Tokenize every item of ``raw_text_iter`` and concatenate the ids.

    Each item is tokenized with the notebook-level ``tokenizer`` and mapped to
    ids with ``vocab``; items that yield no tokens are skipped.  (The BPE
    variant would use ``tokenizer_bpe.tokenize(item)`` instead.)

    Returns:
        A flat ``torch.long`` tensor holding the ids of all items in order.
    """
    pieces = []
    for item in raw_text_iter:
        ids = torch.tensor(vocab(tokenizer(item)), dtype=torch.long)
        if ids.numel() > 0:
            pieces.append(ids)
    return torch.cat(tuple(pieces))

# ``train_iter`` was "consumed" by the process of building the vocab,
# so we have to create it again. Calling ``PennTreebank()`` without a split
# is unpacked here into the train, validation and test iterators.
train_iter, val_iter, test_iter = PennTreebank()
# Each split becomes one flat 1-D tensor of token ids (see ``data_process``).
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def batchify(data: Tensor, bsz: int) -> Tensor:
    """Reshape a flat token stream into ``bsz`` parallel columns.

    Trailing elements that do not fill a complete row are dropped.  Column
    ``j`` of the result holds the ``j``-th contiguous slice of ``data``.

    Arguments:
        data: Tensor, shape ``[N]``
        bsz: int, batch size

    Returns:
        Tensor of shape ``[N // bsz, bsz]``, moved to ``device``
    """
    rows_per_column = data.size(0) // bsz
    usable = data[:rows_per_column * bsz]
    columns = usable.view(bsz, rows_per_column)
    return columns.t().contiguous().to(device)

# Number of parallel token streams (columns) for training and for evaluation.
batch_size = 20
eval_batch_size = 10
# NOTE: these assignments overwrite the flat 1-D tensors with their batched
# form, so this cell is not idempotent — re-run the data-processing cell first.
train_data = batchify(train_data, batch_size)  # shape ``[seq_len, batch_size]``
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

In [11]:
# save_dataset(train_data, "./dataset/WikiText2/train_dataset.pkl")
# save_dataset(test_data, "./dataset/WikiText2/test_dataset.pkl")
# save_dataset(val_data, "./dataset/WikiText2/val_dataset.pkl")

In [5]:
# Sanity check: batched training tensor shape and vocabulary size.
print(train_data.shape)
print(len(vocab))

torch.Size([46220, 20])
9922


### Functions to generate input and target sequence




``get_batch()`` generates a pair of input-target sequences for
the transformer model. It subdivides the source data into chunks of
length ``bptt``. For the language modeling task, the target for each
input token is the token that follows it. For example, with a ``bptt`` value of 2,
we’d get the following two tensors for ``i`` = 0:

<img src="file://../_static/img/transformer_input_target.png">

It should be noted that the chunks are along dimension 0, consistent
with the ``S`` dimension in the Transformer model. The batch dimension
``N`` is along dimension 1.




In [6]:
# Maximum number of time steps per chunk.
bptt = 35
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """Slice one input chunk and its next-token targets out of ``source``.

    The chunk starts at row ``i`` and spans at most ``bptt`` rows; it is
    shortened near the end so that a target row exists for every input row.

    Args:
        source: Tensor, shape ``[full_seq_len, batch_size]``
        i: int, index of the first row of the chunk

    Returns:
        tuple (data, target), where data has shape ``[seq_len, batch_size]`` and
        target has shape ``[seq_len * batch_size]``
    """
    window = min(bptt, len(source) - 1 - i)
    stop = i + window
    inputs = source[i:stop]
    # Targets are the inputs shifted forward by one row, flattened.
    labels = source[i + 1:stop + 1].reshape(-1)
    return inputs, labels

## Initiate an instance




The model hyperparameters are defined below. The ``vocab`` size is
equal to the length of the vocab object.




In [49]:
ntokens = len(vocab)  # size of vocabulary
emsize = 200  # embedding dimension
d_hid = 200  # dimension of the feedforward network model in ``nn.TransformerEncoder``
nlayers = 2  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
nhead = 2  # number of heads in ``nn.MultiheadAttention``
dropout = 0.2  # dropout probability
# Selects LinearAttentionTransformerEncoderLayer (True) or the stock
# TransformerEncoderLayer (False) inside TransformerModel.
linear_attention=True
# The single global ``model`` built here is what the training cells below train.
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, linear_attention,dropout).to(device)

## Run the model




We use [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
with the [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)
(stochastic gradient descent) optimizer. The learning rate is initially set to
5.0 and follows a [StepLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html)
schedule. During training, we use [nn.utils.clip_grad_norm_](https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html)
to prevent gradients from exploding.




In [50]:
import time

criterion = nn.CrossEntropyLoss()
lr = 5 # learning rate
# The optimizer is bound to the parameters of the ``model`` that exists when
# this cell runs; re-run this cell after constructing a new model.
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Decays the learning rate by a factor of 0.95 on each ``scheduler.step()``
# (the training cells call it once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def train(model: nn.Module) -> None:
    """Train ``model`` for one pass over the global ``train_data``.

    Uses the notebook-level ``criterion``, ``optimizer``, ``scheduler``,
    ``bptt`` and ``ntokens``.  Gradients are clipped to a norm of 0.8 before
    every optimizer step.  Every 200 batches a progress line is printed with
    the mean loss over that window; the line also reads the global ``epoch``
    set by the surrounding training loop.
    """
    model.train()  # turn on train mode
    running_loss = 0.
    log_interval = 200
    start_time = time.time()

    num_batches = len(train_data) // bptt
    chunk_starts = range(0, train_data.size(0) - 1, bptt)
    for batch, i in enumerate(chunk_starts):
        data, targets = get_batch(train_data, i)
        logits = model(data).view(-1, ntokens)
        loss = criterion(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.8)
        optimizer.step()

        running_loss += loss.item()
        # Only report at non-zero multiples of ``log_interval``.
        if batch == 0 or batch % log_interval != 0:
            continue

        lr = scheduler.get_last_lr()[0]
        ms_per_batch = (time.time() - start_time) * 1000 / log_interval
        cur_loss = running_loss / log_interval
        ppl = math.exp(cur_loss)
        print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
              f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
              f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
        # Start a fresh reporting window.
        running_loss = 0
        start_time = time.time()

def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    """Return the average per-position loss of ``model`` on ``eval_data``.

    The data is walked in chunks of ``bptt`` rows with gradients disabled.
    Each chunk's mean loss is weighted by its number of rows, and the total
    is divided by ``len(eval_data) - 1``.
    """
    model.eval()  # turn on evaluation mode
    weighted_loss = 0.
    last_start = eval_data.size(0) - 1
    with torch.no_grad():
        for offset in range(0, last_start, bptt):
            data, targets = get_batch(eval_data, offset)
            logits = model(data).view(-1, ntokens)
            chunk_loss = criterion(logits, targets).item()
            weighted_loss += data.size(0) * chunk_loss
    return weighted_loss / (len(eval_data) - 1)

Loop over epochs. Save the model if the validation loss is the best
we've seen so far. Adjust the learning rate after each epoch.



In [47]:
best_val_loss = float('inf')
epochs = 3
print("Self_attention transformer training")
# NOTE(review): this cell trains whatever the global ``model`` currently is,
# using the global ``optimizer``/``scheduler``.  The construction cell above
# sets ``linear_attention=True``, so confirm the model was rebuilt with
# ``linear_attention = False`` (and the optimizer cell re-run) before this run.

# Checkpoints are kept in a persistent directory because the evaluation cell
# loads them by path later.  The previous ``TemporaryDirectory()`` wrapper was
# immediately shadowed by this path and the directory itself was never created.
model_dir = "./model/"
name = "SA"
os.makedirs(model_dir, exist_ok=True)
best_model_params_path = os.path.join(model_dir, str(name)+"_best_model_params.pt")

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)
    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
      f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    # Keep only the checkpoint with the lowest validation loss so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_params_path)

    scheduler.step()
model.load_state_dict(torch.load(best_model_params_path)) # load best model states

Self_attention transformer training
| epoch   1 |   200/ 1320 batches | lr 5.00 | ms/batch  4.90 | loss  7.88 | ppl  2644.63
| epoch   1 |   400/ 1320 batches | lr 5.00 | ms/batch  4.86 | loss  6.85 | ppl   948.60
| epoch   1 |   600/ 1320 batches | lr 5.00 | ms/batch  4.84 | loss  6.28 | ppl   531.14
| epoch   1 |   800/ 1320 batches | lr 5.00 | ms/batch  4.89 | loss  5.99 | ppl   400.63
| epoch   1 |  1000/ 1320 batches | lr 5.00 | ms/batch  4.86 | loss  5.82 | ppl   337.99
| epoch   1 |  1200/ 1320 batches | lr 5.00 | ms/batch  4.85 | loss  5.70 | ppl   298.23
-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  6.70s | valid loss  5.66 | valid ppl   285.89
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 1320 batches | lr 4.75 | ms/batch  4.87 | loss  5.61 | ppl   272.83
| epoch   2 |   400/ 1320 batches | lr 4.75 | ms/batch  5.19 | loss  5.57 | ppl   263.7

In [51]:
best_val_loss = float('inf')
epochs = 3
print("Linear_attention transformer training")
# NOTE(review): this cell trains whatever the global ``model`` currently is,
# using the global ``optimizer``/``scheduler``; rebuild the model with
# ``linear_attention = True`` and re-run the optimizer cell before this run so
# that no state carries over from an earlier training run.

# Checkpoints are kept in a persistent directory because the evaluation cell
# loads them by path later.  The previous ``TemporaryDirectory()`` wrapper was
# immediately shadowed by this path and the directory itself was never created.
model_dir = "./model/"
name = "LA"
os.makedirs(model_dir, exist_ok=True)
best_model_params_path = os.path.join(model_dir, str(name)+"_best_model_params.pt")

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)
    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
      f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    # Keep only the checkpoint with the lowest validation loss so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_params_path)

    scheduler.step()
model.load_state_dict(torch.load(best_model_params_path)) # load best model states

Linear_attention transformer training
| epoch   1 |   200/ 1320 batches | lr 5.00 | ms/batch  5.34 | loss  7.97 | ppl  2895.92
| epoch   1 |   400/ 1320 batches | lr 5.00 | ms/batch  4.94 | loss  6.84 | ppl   931.78
| epoch   1 |   600/ 1320 batches | lr 5.00 | ms/batch  4.94 | loss  6.61 | ppl   740.60
| epoch   1 |   800/ 1320 batches | lr 5.00 | ms/batch  4.96 | loss  6.50 | ppl   662.23
| epoch   1 |  1000/ 1320 batches | lr 5.00 | ms/batch  4.91 | loss  6.24 | ppl   513.88
| epoch   1 |  1200/ 1320 batches | lr 5.00 | ms/batch  4.93 | loss  6.00 | ppl   402.50
-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  6.88s | valid loss  5.89 | valid ppl   360.99
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 1320 batches | lr 4.75 | ms/batch  4.98 | loss  5.83 | ppl   339.76
| epoch   2 |   400/ 1320 batches | lr 4.75 | ms/batch  4.90 | loss  5.76 | ppl   317

## Evaluate the best model on the test dataset




In [54]:
# Rebuild each architecture, load its best checkpoint and score it on the test
# split.  ``map_location=device`` lets checkpoints saved on a GPU be loaded on
# a CPU-only machine (and vice versa).
linear_attention = False
model_self = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, linear_attention,dropout).to(device)
model_self.load_state_dict(torch.load("./model/SA_best_model_params.pt", map_location=device))
test_loss = evaluate(model_self, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| PTB|self attention|word| test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}')

linear_attention = True
model_linear = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, linear_attention,dropout).to(device)
model_linear.load_state_dict(torch.load("./model/LA_best_model_params.pt", map_location=device))
test_loss_linear = evaluate(model_linear, test_data)
test_ppl_linear = math.exp(test_loss_linear)


print('=' * 89)
print(f'| PTB|linear attention|word| test loss {test_loss_linear:5.2f} | '
      f'test ppl {test_ppl_linear:8.2f}')
print('=' * 89)

| PTB|self attention|word| test loss  5.35 | test ppl   210.35
| PTB|linear attention|word| test loss  5.45 | test ppl   232.69


In [35]:
# NOTE(review): evaluates whatever the global ``model`` currently holds, and
# overwrites ``test_loss``/``test_ppl``.  The label below says
# "self attention|BPE", but the data cell shown in this notebook uses the
# word-level ``basic_english`` tokenizer — confirm which configuration
# actually produced the recorded output before relying on it.
test_loss = evaluate(model, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training |PTB|self attention|BPE| test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}')
print('=' * 89)

# test_loss_linear = evaluate(model_linear, test_data)
# test_ppl_linear = math.exp(test_loss_linear)


# print('=' * 89)
# print(f'| End of training |PTB|linear attention|BPE| test loss {test_loss_linear:5.2f} | '
#       f'test ppl {test_ppl_linear:8.2f}')
# print('=' * 89)

| End of training |PTB|self attention|BPE| test loss  9.56 | test ppl 14244.27


In [None]:
# test_loss = evaluate(model, test_data)
# test_ppl = math.exp(test_loss)
# print('=' * 89)
# print(f'| End of training |WiKi-Text2|BPE| test loss {test_loss:5.2f} | '
#       f'test ppl {test_ppl:8.2f}')
# print('=' * 89)

| End of training |WiKi-Text2|BPE| test loss  4.65 | test ppl   104.81


In [None]:
# test_loss = evaluate(model, test_data)
# test_ppl = math.exp(test_loss)
# print('=' * 89)
# print(f'| End of training |WiKi-Text2|Word| test loss {test_loss:5.2f} | '
#       f'test ppl {test_ppl:8.2f}')
# print('=' * 89)

| End of training |WiKi-Text2|Word| test loss  5.50 | test ppl   245.35


In [None]:
# test_loss = evaluate(model, test_data)
# test_ppl = math.exp(test_loss)
# print('=' * 89)
# print(f'| End of training |PTB|Word| test loss {test_loss:5.2f} | '
#       f'test ppl {test_ppl:8.2f}')
# print('=' * 89)

| End of training |PTB|Word| test loss  5.23 | test ppl   187.41


In [None]:
# test_loss = evaluate(model, test_data)
# test_ppl = math.exp(test_loss)
# print('=' * 89)
# print(f'| End of training |PTB|BPE| test loss {test_loss:5.2f} | '
#       f'test ppl {test_ppl:8.2f}')
# print('=' * 89)

| End of training |PTB|BPE| test loss  4.53 | test ppl    93.14


In [20]:
# NOTE(review): evaluates whatever the global ``model`` currently holds, and
# overwrites ``test_loss``/``test_ppl``.  The label below says
# "linear attention|BPE", but the data cell shown in this notebook uses the
# word-level ``basic_english`` tokenizer — confirm which configuration
# actually produced the recorded output before relying on it.
test_loss = evaluate(model, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training |PTB|linear attention|BPE| test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}')
print('=' * 89)

| End of training |PTB|linear attention|BPE| test loss  4.52 | test ppl    92.08
