<a href="https://colab.research.google.com/github/SunshineGreeny/Dive-into-deep-learning-Pytorch/blob/colab-experiments/chapter_attention-mechanisms-and-transformers/bahdanau-attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# The Bahdanau Attention Mechanism



In [51]:
import math
import torch
from torch import nn
from torch.nn import functional as F
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader,Dataset
from collections import Counter
import re
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

Tools

In [59]:
def try_gpu(i=0):
    """返回可用的GPU，否则返回CPU"""
    if torch.cuda.device_count() >= i+1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')


class MTFraEng(torch.utils.data.Dataset):
    def __init__(self, src_texts, tgt_texts, num_steps=10, reserved_tokens=None):
        self.src_tokens = [self._preprocess(t) for t in src_texts]
        self.tgt_tokens = [self._preprocess(t) for t in tgt_texts]
        self.src_vocab = self._build_vocab(self.src_tokens, reserved_tokens)
        self.tgt_vocab = self._build_vocab(self.tgt_tokens, reserved_tokens)
        self.num_steps = num_steps
        self.src_array = self._build_array(self.src_tokens, self.src_vocab, num_steps)
        self.tgt_array = self._build_array(self.tgt_tokens, self.tgt_vocab, num_steps)

    def __len__(self):
        return len(self.src_array)

    def __getitem__(self, idx):
        return self.src_array[idx], self.tgt_array[idx]

    @staticmethod
    def _preprocess(text):
        text = text.lower().strip()
        text = re.sub(r"([.!?])", r" \1", text)
        text = re.sub(r"[^a-zA-Z.!?]+", r" ", text)
        return text.split()

    @staticmethod
    def _build_vocab(tokens_list, reserved_tokens=None, min_freq=1):
        counter = Counter([tk for line in tokens_list for tk in line])
        token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        idx_to_token = ['<pad>', '<bos>', '<eos>'] + (reserved_tokens or [])
        idx_to_token += [tk for tk, freq in token_freqs if freq >= min_freq and tk not in idx_to_token]
        token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}
        return {"idx_to_token": idx_to_token, "token_to_idx": token_to_idx}

    @staticmethod
    def _build_array(lines, vocab, num_steps):
        bos, eos, pad = vocab["token_to_idx"]["<bos>"], vocab["token_to_idx"]["<eos>"], vocab["token_to_idx"]["<pad>"]
        array = []
        for line in lines:
            arr = [bos] + [vocab["token_to_idx"].get(w, pad) for w in line] + [eos]
            if len(arr) < num_steps:
                arr += [pad] * (num_steps - len(arr))
            else:
                arr = arr[:num_steps-1] + [eos]  # 保证最后是 eos
            array.append(arr)
        return torch.tensor(array)


def bleu(preq_seq,label_seq,k=2):
  return sentence_bleu([label_seq.split()],
             preq_seq.split(),
             smoothing_function=SmoothingFunction().method1,
             weights=tuple[1/k]*k)


class AddictiveAttention(nn.Module):
  def __init__(self,num_hiddens,dropout=0.):
    super().__init__()
    self.W_q=nn.Linear(num_hiddens,num_hiddens,bias=False)
    self.W_k=nn.Linear(num_hiddens,num_hiddens,bias=False)
    self.W_v=nn.Linear(num_hiddens,1,bias=False)
    self.dropout=nn.Dropout(dropout)
    self.attention_weights=None

  def forward(self,queries,keys,values,valid_lens=None):
    queries,keys=self.W_q(queries),self.W_k(keys)
    features=queries.unsqueeze(2)+keys.unsqueeze(1)
    scores=self.W_v(torch.tanh(features)).squeeze(-1)
    if valid_lens is not None:
      mask=torch.arange(keys.shape[1],device=keys.device)[None,:]>=valid_lens[:,None]
      scores=scores.masked_fill(mask[:,None],-1e6)
    self.attention_weights=F.softmax(scores,dim=-1)
    return torch.bmm(self.dropout(self.attention_weights),values)


class Decoder(nn.Module):
  def init_state(self,enc_outputs,*args):
    raise NotImplementedError
  def forward(self,X,state):
    raise NotImplementedError


class Seq2SeqEncoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers,
                           dropout=dropout, batch_first=True)

    def forward(self, X):
        X = self.embedding(X)
        output, state = self.rnn(X)
        return output, state  # output: (B, T, H), state: (h_n, c_n)


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt):
        enc_outputs = self.encoder(src)
        dec_state = self.decoder.init_state(enc_outputs)
        dec_input = tgt[:, :-1]
        logits, _ = self.decoder(dec_input, dec_state, enc_outputs)
        return logits


def init_seq2seq(src_vocab_size, tgt_vocab_size, embed_size=256, num_hiddens=256, num_layers=2, dropout=0.1):
    encoder = Seq2SeqEncoder(src_vocab_size, embed_size, num_hiddens, num_layers, dropout)
    decoder = Seq2SeqAttentionDecoder(tgt_vocab_size, embed_size, num_hiddens, num_layers, dropout)
    return Seq2Seq(encoder, decoder)


def check_shape(tensor, expected_shape):
    assert tensor.shape == expected_shape, \
        f'Expected shape: {expected_shape}, got: {tensor.shape}'
    return tensor


def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """Show heatmaps of matrices."""
    num_rows,num_cols=matrices.shape[0],matrices.shape[1]
    # 创建子图网络
    fig,axes=plt.subplots(num_rows,num_cols,figsize=figsize,
                          sharex=True,sharey=True,squeeze=False)
    # 遍历所有矩阵并绘制热力图
    for i in range(num_rows):
        for j in range(num_cols):
            ax=axes[i,j]
            # 将张量转换为numpy数组并分离计算图
            matrix=matrices[i,j].detach().numpy()

            # 使用热力图显示矩阵
            pcm=ax.imshow(matrix,cmap='Reds')

            # 设置坐标轴标签
            if i==num_rows-1:#最后一行显示x轴标签
                ax.set_xlabel(xlabel)
            if j==0:#第一列显示y轴标签
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])

            # 隐藏刻度线
            ax.xaxis.set_ticks_position('none')
            ax.yaxis.set_ticks_position('none')

    # Add a colorbar
    fig.colorbar(pcm, ax=axes.ravel().tolist())
    plt.show()

class Trainer:
    def __init__(self, model, lr=0.005, device=None):
        self.device = torch.device('cpu') # Explicitly set device to CPU
        self.model = model.to(self.device)
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=0)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def fit(self, train_iter, num_epochs=5):
        for epoch in range(num_epochs):
            self.model.train()
            total_loss = 0
            for src, tgt in train_iter:
                src, tgt = src.to(self.device), tgt.to(self.device)
                logits = self.model(src, tgt)
                y = tgt[:,1:].reshape(-1)
                l = self.loss_fn(logits.reshape(-1, logits.shape[-1]), y)
                self.optimizer.zero_grad()
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                self.optimizer.step()
                total_loss += l.item()
            print(f"epoch {epoch+1}, loss {total_loss/len(train_iter):.3f}")

The base interface for decoders with attention

In [53]:
class AttentionDecoder(Decoder):
    """The base attention-based decoder interface."""
    def __init__(self):
        super().__init__()

    @property
    def attention_weights(self):
        raise NotImplementedError

Implement the RNN decoder

In [54]:
class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0):
        super().__init__()
        self.attention = AddictiveAttention(num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers,
            dropout=dropout)
        self.dense = nn.LazyLinear(vocab_size)
        # The original code had init_seq2seq here, which is not defined for the decoder.
        # Removing this line.

    def init_state(self, enc_outputs, enc_hidden_state, enc_valid_lens):
        return (enc_outputs.permute(1, 0, 2), enc_hidden_state, enc_valid_lens)

    def forward(self, X, state):
        enc_outputs, hidden_state, enc_valid_lens = state
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens)
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,
                                          enc_valid_lens]

    @property
    def attention_weights(self):
        return self._attention_weights

Test the implemented
decoder

In [60]:
vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
batch_size, num_steps = 4, 7
encoder = Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)
decoder = Seq2SeqAttentionDecoder(vocab_size, embed_size, num_hiddens,
                                  num_layers)
X = torch.zeros((batch_size, num_steps), dtype=torch.long)
enc_outputs, enc_hidden_state = encoder(X)
# Pass enc_outputs, enc_hidden_state, and enc_valid_lens to init_state
state = decoder.init_state(enc_outputs, enc_hidden_state, torch.tensor([num_steps] * batch_size))
output, state = decoder(X, state)
check_shape(output, (batch_size, num_steps, vocab_size))
check_shape(state[0], (batch_size, num_steps, num_hiddens))
check_shape(state[1][0], (batch_size, num_hiddens))

RuntimeError: The size of tensor a (4) must match the size of tensor b (7) at non-singleton dimension 1

Training

In [61]:
# Placeholder data - Replace with your actual source and target sentences
src_texts = ["I am a student.", "He likes apples.", "She is reading."]
tgt_texts = ["Je suis étudiant.", "Il aime les pommes.", "Elle lit."]
num_steps = 10
batch_size = 2

dataset = MTFraEng(src_texts, tgt_texts, num_steps=num_steps)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
embed_size, num_hiddens, num_layers, dropout = 256, 256, 2, 0.2
encoder = Seq2SeqEncoder(
    len(data.src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(data.tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
model = Seq2Seq(encoder, decoder) # Removed tgt_pad and lr as they are not in Seq2Seq constructor
trainer = Trainer(model, lr=0.005) # Pass lr to Trainer
trainer.fit(train_loader, num_epochs=30) # Changed data to train_loader and added num_epochs

IndexError: index out of range in self

Translate a few English sentences

In [56]:
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
preds, _ = model.predict_step(
    data.build(engs, fras), try_gpu(), data.num_steps)
for en, fr, p in zip(engs, fras, preds):
    translation = []
    for token in data.tgt_vocab.to_tokens(p):
        if token == '<eos>':
            break
        translation.append(token)
    print(f'{en} => {translation}, bleu,'
          f'{bleu(" ".join(translation), fr, k=2):.3f}')

AttributeError: 'Seq2Seq' object has no attribute 'predict_step'

Visualize the attention weights

In [None]:
_, dec_attention_weights = model.predict_step(
    data.build([engs[-1]], [fras[-1]]), try_gpu(), data.num_steps, True)
attention_weights = torch.cat(
    [step[0][0][0] for step in dec_attention_weights], 0)
attention_weights = attention_weights.reshape((1, 1, -1, data.num_steps))

show_heatmaps(
    attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),
    xlabel='Key positions', ylabel='Query positions')