<a href="https://colab.research.google.com/github/SunshineGreeny/Dive-into-deep-learning-Pytorch/blob/main/chapter_attention-mechanisms-and-transformers/bahdanau-attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# The Bahdanau Attention Mechanism



Tools

In [46]:
import math
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import matplotlib.pyplot as plt

def try_gpu(i=0):
    """返回可用的GPU，否则返回CPU"""
    if torch.cuda.device_count() >= i+1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')


def bleu(preq_seq,label_seq,k=2):
  return sentence_bleu([label_seq.split()],
             preq_seq.split(),
             smoothing_function=SmoothingFunction().method1,
             weights=(1. / k,) * k)


def check_shape(tensor, expected_shape):
    assert tensor.shape == expected_shape, \
        f"Expected shape: {expected_shape}, but got: {tensor.shape}"
    return tensor


def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(3.5, 2.5),
                  sharex=True,sharey=True,squeeze=False):
    num_rows,num_cols=matrices.shape[0],matrices.shape[1]

    # 创建子图网络
    fig,axes=plt.subplots(num_rows,num_cols,figsize=figsize,
                          sharex=True,sharey=True,squeeze=False)
    # 遍历所有矩阵并绘制热力图
    for i in range(num_rows):
        for j in range(num_cols):
            ax=axes[i,j]
            # 将张量转换为numpy数组并分离计算图
            matrix=matrices[i,j].detach().numpy()

            # 使用热力图显示矩阵
            pcm=ax.imshow(matrix,cmap='Reds')

            # 设置坐标轴标签
            if i==num_rows-1:#最后一行显示x轴标签
                ax.set_xlabel(xlabel)
            if j==0:#第一列显示y轴标签
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])

            # 隐藏刻度线
            ax.xaxis.set_ticks_position('none')
            ax.yaxis.set_ticks_position('none')

    # Add a colorbar
    fig.colorbar(pcm, ax=axes.ravel().tolist())
    plt.show()


def masked_softmax(X, valid_lens):
    def _sequence_mask(X, valid_len, value=0):
        maxlen = X.size(1)

        # 创建掩码：对于每个位置，如果索引小于有效长度则为True，否则为False
        mask = torch.arange((maxlen), dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]

        # 将掩码为False的位置设置为指定的值
        X[~mask] = value
        return X

    # 没有提供有效长度，直接返回标准sofemax
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape

        # 处理不同维度的有效长度
        if valid_lens.dim() == 1:
            # 如果是一维,复制到每个位置
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)

        # 应用序列掩码,将无效位置设置为很小的值
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)

        # 应用softmax并恢复原始形状
        return nn.functional.softmax(X.reshape(shape), dim=-1)


class DotProductAttention(nn.Module):
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # 获取特征维度（用于缩放）
        d = queries.shape[-1]

        # 计算注意力分数：Q * K^T / √d
        # bmm: 批量矩阵乘法
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)

        # 应用带掩码的softmax得到注意力权重
        self.attention_weights = masked_softmax(scores, valid_lens)

        # 对注意力权重应用dropout，然后与值相乘得到最终输出
        return torch.bmm(self.dropout(self.attention_weights), values)


class MultiHeadAttention(nn.Module):
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=bias) # Changed LazyLinear to Linear
        self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=bias) # Changed LazyLinear to Linear
        self.W_v = nn.Linear(num_hiddens, num_hiddens, bias=bias) # Changed LazyLinear to Linear - Note: W_v should map to num_hiddens for consistency before final W_o
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias) # Changed LazyLinear to Linear

    def transpose_qkv(self, X):
        """Transposes the last two dimensions of a tensor."""
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        X = X.permute(0, 2, 1, 3)
        return X.reshape(-1, X.shape[2], X.shape[3])

    def transpose_output(self, X):
        """Reverse the operation of transpose_qkv."""
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)


    def forward(self, queries, keys, values, valid_lens):
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        output = self.attention(queries, keys, values, valid_lens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)


class MTFraEng(Dataset):
    def __init__(self, src_sentences, tgt_sentences, src_vocab, tgt_vocab, max_len):
        self.src_sentences = src_sentences
        self.tgt_sentences = tgt_sentences
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        src_sentence = self.src_sentences[idx]
        tgt_sentence = self.tgt_sentences[idx]

        src_tokens = [''] + src_sentence.split() + ['']
        tgt_tokens = [''] + tgt_sentence.split() + ['']

        src_indices = [self.src_vocab[token] for token in src_tokens if token in self.src_vocab]
        tgt_indices = [self.tgt_vocab[token] for token in tgt_tokens if token in self.tgt_vocab]

        # 填充或截断到固定长度
        src_indices = self.pad_or_truncate(src_indices, self.max_len, self.src_vocab[''])
        tgt_indices = self.pad_or_truncate(tgt_indices, self.max_len, self.tgt_vocab[''])

        return torch.tensor(src_indices), torch.tensor(tgt_indices)

    def pad_or_truncate(self, sequence, max_len, pad_token):
        if len(sequence) < max_len:
            return sequence + [pad_token] * (max_len - len(sequence))
        else:
            return sequence[:max_len]

    def build_vocab(sentences, min_freq=2):
        tokens = []
        for sentence in sentences:
            tokens.extend(sentence.split())

        token_freq = {}
        for token in tokens:
            token_freq[token] = token_freq.get(token, 0) + 1

        vocab = {'': 0, '': 1, '': 2}
        for token, freq in token_freq.items():
            if freq >= min_freq:
                vocab[token] = len(vocab)

        # Add get_itos method for demonstration
        vocab['get_itos'] = lambda: {v: k for k, v in vocab.items()}

        return vocab

    def build(self, src_sentences, tgt_sentences):
        """Builds the dataset for prediction."""
        src_indices = []
        tgt_indices = []
        for src_sentence, tgt_sentence in zip(src_sentences, tgt_sentences):
            src_tokens = [''] + src_sentence.split() + ['']
            tgt_tokens = [''] + tgt_sentence.split() + ['']
            src_indices.append([self.src_vocab[token] for token in src_tokens if token in self.src_vocab])
            tgt_indices.append([self.tgt_vocab[token] for token in tgt_tokens if token in self.tgt_vocab])

        # Pad or truncate to max_len
        src_indices = [self.pad_or_truncate(indices, self.max_len, self.src_vocab['']) for indices in src_indices]
        tgt_indices = [self.pad_or_truncate(indices, self.max_len, self.tgt_vocab['']) for indices in tgt_indices]

        return torch.tensor(src_indices), torch.tensor(tgt_indices)


class PositionalEncoding(nn.Module):
    """Positional encoding."""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)


def Trainer(model, train_loader, val_loader, num_epochs, lr, device):
    model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略填充token
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # 训练阶段
        model.train()
        total_train_loss = 0
        for src, tgt in train_loader:
            src, tgt = src.to(device), tgt.to(device)

            optimizer.zero_grad()
            enc_outputs, _ = model.encoder(src)
            state = model.decoder.init_state(enc_outputs, None)

            # 解码器输入是目标序列去掉最后一个token
            dec_input = tgt[:, :-1]
            # 解码器输出应该预测目标序列去掉第一个token
            dec_output, _ = model.decoder(dec_input, state)

            loss = criterion(dec_output.reshape(-1, dec_output.shape[-1]), tgt[:, 1:].reshape(-1))
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # 验证阶段
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for src, tgt in val_loader:
                src, tgt = src.to(device), tgt.to(device)

                enc_outputs, _ = model.encoder(src)
                state = model.decoder.init_state(enc_outputs, None)
                dec_input = tgt[:, :-1]
                dec_output, _ = model.decoder(dec_input, state)

                loss = criterion(dec_output.reshape(-1, dec_output.shape[-1]), tgt[:, 1:].reshape(-1))
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

    # 绘制损失曲线
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, num_epochs+1), train_losses, label='Train Loss')
    plt.plot(range(1, num_epochs+1), val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    plt.show()


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, tgt_vocab):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_vocab = tgt_vocab

    def forward(self, src, tgt):
        enc_outputs = self.encoder(src)
        dec_state = self.decoder.init_state(enc_outputs, None)
        dec_input = tgt[:, :-1]
        logits, _ = self.decoder(dec_input, dec_state)
        return logits

    def predict_step(self, src_seqs, device, max_len, save_attention_weights=False):
        """Predict the output sequence."""
        self.eval()
        batch_size = len(src_seqs)
        enc_outputs, _ = self.encoder(src_seqs.to(device))
        dec_state = self.decoder.init_state(enc_outputs, None)
        # Start with the  token
        dec_X = torch.tensor([self.tgt_vocab['']] * batch_size,
                             device=device).unsqueeze(1)
        preds = []
        attention_weights = []
        for _ in range(max_len):
            Y, dec_state = self.decoder(dec_X, dec_state)
            # Use the token with the highest predicted probability as the next
            # input to the decoder
            dec_X = Y.argmax(dim=2)
            preds.extend(dec_X.squeeze(1).tolist())
            # Save attention weights (optional)
            if save_attention_weights:
                attention_weights.append(self.decoder.attention_weights)
        return torch.tensor(preds).reshape(batch_size, -1), attention_weights


class AddictiveAttention(nn.Module):
    def __init__(self, num_hiddens, dropout=0.):
        super().__init__()
        self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=False)
        self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=False)
        self.W_v = nn.Linear(num_hiddens, 1, bias=False) # Changed output features to 1 for the final scoring
        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None

    def forward(self, queries, keys, values, valid_lens=None):
        # queries shape: (batch_size, num_hiddens) - assuming query_len = 1
        # keys shape: (batch_size, key_len, num_hiddens)
        # values shape: (batch_size, key_len, num_hiddens)

        # Apply linear transformations
        # queries_transformed shape: (batch_size, num_hiddens)
        queries_transformed = self.W_q(queries)
        # keys_transformed shape: (batch_size, key_len, num_hiddens)
        keys_transformed = self.W_k(keys)

        # Expand dimensions for broadcasting and addition
        # queries_transformed_expanded: (batch_size, 1, num_hiddens)
        # keys_transformed_expanded: (batch_size, key_len, num_hiddens) -> (batch_size, 1, key_len, num_hiddens)
        # We need to expand queries to match key_len dimension and keys to match query_len dimension (which is 1)
        # The sum should be (batch_size, 1, key_len, num_hiddens)
        features = queries_transformed.unsqueeze(1).unsqueeze(2) + keys_transformed.unsqueeze(1)


        # Apply the tanh and W_v layer, then squeeze the last dimension
        # features: (batch_size, 1, key_len, num_hiddens) -> (batch_size, 1, key_len, 1)
        # scores: (batch_size, 1, key_len)
        scores = self.W_v(torch.tanh(features)).squeeze(-1)

        if valid_lens is not None:
             # Need to handle valid_lens masking for additive attention scores (batch_size, 1, key_len)
            # The mask should be applied to the last dimension (key_len)
            mask = torch.arange(keys.shape[1], device=keys.device).unsqueeze(0).unsqueeze(0) >= valid_lens.unsqueeze(-1)
            scores = scores.masked_fill(mask, -1e6)


        self.attention_weights = F.softmax(scores, dim=-1) # Shape: (batch_size, 1, key_len)
        return torch.bmm(self.dropout(self.attention_weights), values) # Result shape: (batch_size, 1, num_hiddens)


class Decoder(nn.Module):
  def init_state(self,enc_outputs,*args):
    raise NotImplementedError
  def forward(self,X,state):
    raise NotImplementedError


class Seq2SeqEncoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers,
                           dropout=dropout, batch_first=True)

    def forward(self, X):
        X = self.embedding(X)
        output, state = self.rnn(X)
        return output, state  # output: (B, T, H), state: (h_n, c_n)


def init_seq2seq(src_vocab_size, tgt_vocab_size, embed_size=256, num_hiddens=256, num_layers=2, dropout=0.1):
    encoder = Seq2SeqEncoder(src_vocab_size, embed_size, num_hiddens, num_layers, dropout)
    decoder = Seq2SeqAttentionDecoder(tgt_vocab_size, embed_size, num_hiddens, num_layers, dropout)
    return Seq2Seq(encoder, decoder)

The base interface for decoders with attention

In [38]:
class AttentionDecoder(Decoder):
    """The base attention-based decoder interface."""
    def __init__(self):
        super().__init__()

    @property
    def attention_weights(self):
        raise NotImplementedError

Implement the RNN decoder

In [39]:
class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0):
        super().__init__()
        self.attention = AddictiveAttention(num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers,
            dropout=dropout, batch_first=False) # Ensure batch_first is False

        self.dense = nn.LazyLinear(vocab_size)
        # Removed: self.apply(init_seq2seq)

    def init_state(self, enc_outputs, enc_valid_lens):
        outputs, hidden_state = enc_outputs
        # Corrected: Do not permute outputs
        return (outputs, hidden_state, enc_valid_lens)

    def forward(self, X, state):
        enc_outputs, hidden_state, enc_valid_lens = state
        X = self.embedding(X).permute(1, 0, 2) # Permute to (seq_len, batch_size, embed_size) for GRU
        outputs, self._attention_weights = [], []
        for x in X:
            # Pass the hidden state directly as the query
            query = hidden_state[-1] # Shape: (batch_size, num_hiddens)
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens) # enc_outputs shape: (batch_size, seq_len, num_hiddens)

            # Reshape context to (1, batch_size, num_hiddens) to match GRU input expectation
            context = context.unsqueeze(0)

            # x shape: (batch_size, embed_size)
            # Unsqueeze x to (1, batch_size, embed_size) to match GRU input expectation
            x = x.unsqueeze(0)

            # Concatenate context and embedded input
            rnn_input = torch.cat((context, x), dim=-1) # Shape: (1, batch_size, num_hiddens + embed_size)

            out, hidden_state = self.rnn(rnn_input, hidden_state) # out shape: (1, batch_size, num_hiddens), hidden_state shape: (num_layers, batch_size, num_hiddens)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)

        # Concatenate outputs from all time steps
        outputs = torch.cat(outputs, dim=0) # Shape: (seq_len, batch_size, num_hiddens)

        # Apply dense layer (LazyLinear expects last dim as input)
        outputs = self.dense(outputs) # Shape: (seq_len, batch_size, vocab_size)

        # Permute back to (batch_size, seq_len, vocab_size)
        outputs = outputs.permute(1, 0, 2)

        return outputs, [enc_outputs, hidden_state,
                                          enc_valid_lens]

    @property
    def attention_weights(self):
        # Assuming attention_weights are collected per step and have shape (batch_size, 1, key_len)
        # Concatenate along the query_len dimension (dim 1 after unsqueezing in attention)
        # The stored weights are already (batch_size, 1, key_len), concatenating a list of these
        # will result in (seq_len, batch_size, 1, key_len).
        # We need to reshape or permute to get (batch_size, seq_len, key_len)
        if self._attention_weights:
            # Stack the list of tensors
            stacked_weights = torch.stack(self._attention_weights, dim=1) # Shape: (batch_size, seq_len, 1, key_len)
            # Squeeze the dimension of size 1
            return stacked_weights.squeeze(2) # Shape: (batch_size, seq_len, key_len)
        else:
            return None

Test the implemented
decoder

In [47]:
vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
batch_size, num_steps = 4, 7
encoder = Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)
decoder = Seq2SeqAttentionDecoder(vocab_size, embed_size, num_hiddens,
                                  num_layers)
X = torch.zeros((batch_size, num_steps), dtype=torch.long)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
check_shape(output, (batch_size, num_steps, vocab_size))
check_shape(state[0], (batch_size, num_steps, num_hiddens))
check_shape(state[1][0], (batch_size, num_hiddens))

RuntimeError: The size of tensor a (4) must match the size of tensor b (7) at non-singleton dimension 3

In [41]:
# Placeholder code to load and preprocess data

# In a real scenario, you would load your French and English sentences
# from a file or another source.
src_sentences = ["This is a test sentence .", "Another example ."]
tgt_sentences = ["Ceci est une phrase de test .", "Un autre exemple ."]

# Build vocabularies
# Assuming build_vocab method is available in the MTFraEng class or elsewhere
# If not, you would need to implement vocabulary building.
src_vocab = MTFraEng.build_vocab(src_sentences)
tgt_vocab = MTFraEng.build_vocab(tgt_sentences)

# Determine max_len based on your dataset or a predefined value
max_len = 10  # Example value

# Now initialize the dataset correctly
data = MTFraEng(src_sentences, tgt_sentences, src_vocab, tgt_vocab, max_len)

# The rest of the code from cell a73f9cc6 can follow here
embed_size, num_hiddens, num_layers, dropout = 256, 256, 2, 0.2
encoder = Seq2SeqEncoder(
    len(data.src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(data.tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
model = Seq2Seq(encoder, decoder, tgt_vocab=data.tgt_vocab) # Pass tgt_vocab
# The Trainer class is not defined in the provided code. Assuming a Trainer class exists and is imported or defined elsewhere.
# Placeholder for Trainer initialization and fitting.
# trainer = Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)
# trainer.fit(model, data)

print("Placeholder data loaded and model initialized.")

Placeholder data loaded and model initialized.


Training

In [48]:
# data = MTFraEng(batch_size=128) # Removed incorrect initialization

# Assume src_sentences, tgt_sentences, src_vocab, tgt_vocab, and max_len are defined from previous steps
# Re-initialize data, encoder, decoder, and model using the preprocessed data
data = MTFraEng(src_sentences, tgt_sentences, src_vocab, tgt_vocab, max_len)
embed_size, num_hiddens, num_layers, dropout = 256, 256, 2, 0.2
encoder = Seq2SeqEncoder(
    len(data.src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(data.tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
model = Seq2Seq(encoder, decoder, tgt_vocab=data.tgt_vocab)

# Assuming a Trainer class is defined elsewhere or you have implemented a training loop
# Placeholder for Trainer initialization and fitting.
# You would need to replace this with your actual training code.
# For now, I'll use the Trainer function defined in GBEdEM_K7aRG
device = try_gpu()
lr = 0.005 # Assuming a learning rate
num_epochs = 30 # Assuming number of epochs
train_loader = DataLoader(data, batch_size=128, shuffle=True) # Assuming batch size
# Create a dummy validation loader for now
val_loader = DataLoader(data, batch_size=128) # Replace with actual validation data


Trainer(model, train_loader, val_loader, num_epochs, lr, device)

RuntimeError: The size of tensor a (256) must match the size of tensor b (10) at non-singleton dimension 0

Translate a few English sentences

In [49]:
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
preds, _ = model.predict_step(
    data.build(engs, fras), try_gpu(), data.num_steps)
for en, fr, p in zip(engs, fras, preds):
    translation = []
    for token in data.tgt_vocab.to_tokens(p):
        if token == '<eos>':
            break
        translation.append(token)
    print(f'{en} => {translation}, bleu,'
          f'{bleu(" ".join(translation), fr, k=2):.3f}')

AttributeError: 'MTFraEng' object has no attribute 'num_steps'

Visualize the attention weights

In [50]:
_, dec_attention_weights = model.predict_step(
    data.build([engs[-1]], [fras[-1]]), try_gpu(), data.num_steps, True)
attention_weights = torch.cat(
    [step[0][0][0] for step in dec_attention_weights], 0)
attention_weights = attention_weights.reshape((1, 1, -1, data.num_steps))

show_heatmaps(
    attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),
    xlabel='Key positions', ylabel='Query positions')

AttributeError: 'MTFraEng' object has no attribute 'num_steps'