# Attention mechanisms

## Dataset

We will use the dataset from <https://github.com/Charlie9/enron_intent_dataset_verified>. It consists of sentences taken from emails sent between employees of the Enron corporation. Each sentence has been manually labeled according to whether or not it contains a request. We will train an attention-based model to classify sentences as "request" or "no request".

In [None]:
def read_intent_file(file_path: str, *, encoding: str = "utf-8",
                     skip_blank: bool = True) -> list[str]:
    """Read one sentence per line from a plain-text intent file.

    Args:
        file_path: Path of the file to read (e.g. ``intent_pos`` or
            ``intent_neg``).
        encoding: Text encoding used to decode the file. Given explicitly so
            the result does not depend on the platform's locale encoding.
        skip_blank: If True (default), drop lines that are empty after
            stripping. A blank line tokenizes to an empty sequence; after
            padding it becomes a row in which every position is masked, which
            typically yields NaN in masked-softmax attention.

    Returns:
        The stripped lines, in file order.
    """
    with open(file_path, "r", encoding=encoding) as file:
        stripped = (line.strip() for line in file)
        # Materialize inside the `with` so the file is still open while read.
        return [line for line in stripped if line or not skip_blank]

# File of sentences labeled as requests (positive) and file of sentences
# labeled as non-requests (negative); each file holds one sentence per line.
pos_intent_path = "data/Enron/intent_pos"
neg_intent_path = "data/Enron/intent_neg"

pos_intent_sentences, neg_intent_sentences = (
    read_intent_file(path) for path in (pos_intent_path, neg_intent_path)
)

## Tokenization

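Before a model can use the sentences, we need to tokenize them: split each sentence into tokens and map every token to an integer ID using a vocabulary.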

In [None]:
from torchtext.data.utils import get_tokenizer

tokenizer = get_tokenizer('basic_english')  # torchtext's built-in English tokenizer
sample_sentence = "Please send me the report by EOD."
tokens = tokenizer(sample_sentence)
tokens  # shown as the cell output

In [None]:
from torchtext.vocab import build_vocab_from_iterator

def yield_tokens(data_iter):
    """Lazily tokenize each raw sentence in ``data_iter``.

    Uses the module-level ``tokenizer`` and yields one token list per input
    sentence, in order. It is used to feed ``build_vocab_from_iterator``.
    """
    yield from (tokenizer(sentence) for sentence in data_iter)

# Build the vocabulary over every sentence of both classes. NOTE: there is no
# train/test split in this notebook, so the vocabulary covers all the data.
all_sentences = [*pos_intent_sentences, *neg_intent_sentences]

vocab = build_vocab_from_iterator(
    yield_tokens(all_sentences),
    specials=["<unk>", "<pad>"],
)
# Any token not in the vocabulary maps to the '<unk>' index.
vocab.set_default_index(vocab["<unk>"])

In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence

def encode_batch(batch):
    """Turn raw ``(sentence, label)`` pairs into a padded batch of token ids.

    Each sentence is tokenized with the module-level ``tokenizer`` and mapped
    to ids with ``vocab``. Sequences are right-padded with the ``<pad>`` id.
    NOTE: the DataLoader in this notebook uses ``collate_fn``; this helper is
    not called in the visible code.

    Args:
        batch: list of ``(raw_sentence, label)`` pairs.

    Returns:
        ``(padded, labels)``: a ``[batch, max_len]`` LongTensor of token ids
        and a LongTensor built from the labels.
    """
    sequences = []
    for sentence, _label in batch:
        ids = vocab(tokenizer(sentence))
        sequences.append(torch.tensor(ids, dtype=torch.long))
    padded = pad_sequence(sequences, batch_first=True,
                          padding_value=vocab['<pad>'])
    targets = torch.tensor([label for _sentence, label in batch],
                           dtype=torch.long)
    return padded, targets

In [None]:
from attention import EnronRequestDataset
from torch.utils.data import DataLoader

from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Collate ``(token_ids, label)`` dataset items into a padded batch.

    Args:
        batch: sequence of ``(token_ids, label)`` pairs from the dataset.
            Presumably ``token_ids`` is a 1-D LongTensor and ``label`` an int;
            TODO confirm against ``EnronRequestDataset``.

    Returns:
        ``(src, labels, pad_mask)``:
            - ``src`` is ``[batch, max_len]``, right-padded with the
              ``<pad>`` index.
            - ``labels`` is a LongTensor built from the labels.
            - ``pad_mask`` is True at padding positions and False at real
              tokens, the convention ``src_key_padding_mask`` expects.
    """
    sequences, targets = zip(*batch)
    pad_id = vocab['<pad>']
    # Right-pad every sequence to the longest one in this batch.
    src = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
    pad_mask = src.eq(pad_id)
    return src, torch.tensor(targets, dtype=torch.long), pad_mask

# — now wrap in DataLoader —
# Label convention: 1 = request (intent_pos), 0 = no request (intent_neg).
sentences = pos_intent_sentences + neg_intent_sentences
labels = [1] * len(pos_intent_sentences) + [0] * len(neg_intent_sentences)

dataset = EnronRequestDataset(sentences, labels, vocab, tokenizer)
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0,
    # Pinned host memory only speeds up host->GPU copies. On a CPU-only
    # machine it has no benefit, and recent PyTorch versions warn about it.
    pin_memory=torch.cuda.is_available(),
)


In [None]:
import importlib

import torch.nn as nn

import attention

# Reload attention.py so edits to the model code take effect without
# restarting the kernel.
importlib.reload(attention)

# Train on the GPU when one is available; every batch is moved to the same
# device as the model.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = attention.RequestClassifier(len(vocab)).to(device)

opt = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

epochs = 50

for epoch in range(epochs):
    # The evaluation cells call model.eval(). Switch back to training mode at
    # every epoch so re-running this cell trains with training-mode behavior
    # (e.g. dropout, if RequestClassifier uses it).
    model.train()
    running_loss = 0.0
    n_batches = 0
    diverged = False

    # Use `batch_labels` (not `labels`) so the loop does not overwrite the
    # global `labels` list used to build the dataset.
    for src, batch_labels, pad_mask in loader:
        src = src.to(device)
        batch_labels = batch_labels.to(device)
        pad_mask = pad_mask.to(device)

        logits = model(src, src_key_padding_mask=pad_mask)
        # Check before backward()/step() so NaNs never reach the weights.
        if torch.isnan(logits).any():
            diverged = True
            break

        loss = loss_fn(logits, batch_labels)
        opt.zero_grad()
        loss.backward()
        opt.step()

        running_loss += loss.item()
        n_batches += 1

    if diverged:
        print("🛑 NaN in logits!")
        break  # abort all remaining epochs, not just the current one

    # Report the mean over all batches rather than only the last batch's loss.
    # max(..., 1) also avoids dividing by zero if the loader yields nothing.
    mean_loss = running_loss / max(n_batches, 1)
    print(f"Epoch {epoch+1}/{epochs}, Mean loss: {mean_loss:.4f}")


## Evaluation

In [None]:
# Evaluate single sentence
import torch
import torch.nn.functional as F

def predict_sentence(model, sentence, vocab, tokenizer, device):
    """Classify a single raw sentence with ``model``.

    Puts ``model`` into eval mode and runs it without gradient tracking. The
    padding mask is True at ``<pad>`` positions, the Transformer
    ``src_key_padding_mask`` convention.

    Args:
        model: classifier called as ``model(ids, src_key_padding_mask=...)``.
        sentence: raw text to classify.
        vocab: maps a token list to ids when called; ``vocab['<pad>']`` gives
            the padding id.
        tokenizer: callable turning a string into a list of tokens.
        device: device the input ids are moved to.

    Returns:
        ``(pred, probs)``: the argmax class index and the softmax
        probabilities as a Python list.
    """
    pad_id = vocab['<pad>']
    model.eval()
    with torch.no_grad():
        token_ids = vocab(tokenizer(sentence))
        batch = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
        padding = batch == pad_id
        scores = model(batch, src_key_padding_mask=padding)
        distribution = F.softmax(scores, dim=-1)
        label = distribution.argmax(dim=-1).item()
    return label, distribution.squeeze().cpu().tolist()

# — Example usage —
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)  # the trained RequestClassifier

# Class index -> label name (1 = request, matching the training labels).
# Defined once, before the loop, so it exists even when test_sentences is
# empty.
label_map = {0: "no_request", 1: "request"}

test_sentences = [
    "Please send me the report by EOD.",
    "I need the report ASAP.",
    "Can you send me the report?",
    "You used to have a white cat.",
    "The weather is nice today.",
    "Knut is giving a lecture.",
    "Knut, please give the lecture.",
    "This is a test",
    "Cats are blue. I want something blue. Get me a blue cat.",
    "I am requesting a book.",
    "I am requesting a book, please.",
    "Please, I am requesting a book.",
    "Please, please, I am requesting a book.",
]

for sentence in test_sentences:
    pred, probs = predict_sentence(model, sentence, vocab, tokenizer, device)
    print(f"→ {sentence!r}")
    print(f"Prediction: {label_map[pred]} (P(request)={probs[1]:.4f})")


## Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import utility
importlib.reload(utility)

def plot_attention(tokens, attn: np.ndarray):
    """Show a square attention matrix as a heat map with token tick labels.

    Args:
        tokens: token strings used to label both axes. Assumed to have the
            same length as each side of ``attn``.
        attn: ``[S, S]`` array; rows are query positions, columns are key
            positions.
    """
    positions = range(len(tokens))
    fig, ax = plt.subplots()
    image = ax.matshow(attn)
    fig.colorbar(image)
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(positions)
    ax.set_yticklabels(tokens)
    ax.set_xlabel("Key positions")
    ax.set_ylabel("Query positions")
    plt.show()


text = "Please help me."
tokens = tokenizer(text)

# Build the input on whatever device the model lives on. An earlier cell may
# have moved it to CUDA, and CPU input to a CUDA model raises a device-mismatch
# error.
model_device = next(model.parameters()).device
ids = torch.tensor([vocab(tokens)], dtype=torch.long, device=model_device)
padm = ids == vocab['<pad>']

model.eval()
with torch.no_grad():
    logits, attn = model(ids, src_key_padding_mask=padm, return_attn=True)
    # NOTE(review): attn is assumed to be [1, seq_len, seq_len] (single
    # head); confirm against RequestClassifier.
    attn_np = np.asarray(attn[0].detach().cpu().tolist())

plot_attention(tokens, attn_np)

# Column sums: the total attention each token receives as a key, summed over
# all query positions.
col_importance = attn_np[:, :len(tokens)].sum(axis=0).tolist()

# Pair each token with its score and sort, highest attention first.
token_scores = sorted(zip(tokens, col_importance),
                      key=lambda pair: pair[1], reverse=True)

print("Top tokens by attention paid to them:")
for token, score in token_scores:
    print(f"  {token:>10} → {score:.3f}")

utility.display_tokens_with_alpha(tokens, col_importance)


Now let's visualize several sentences and see which words receive the most attention.

In [None]:
test_sentences = [
    "Please send me the report by EOD.",
    "I need the report ASAP.",
    "Please help me out with the report.",
]

# Class index -> label name (1 = request). Defined here so this cell does not
# depend on state created by an earlier cell.
label_map = {0: "no_request", 1: "request"}

# Loop invariants: evaluation mode, the model's device (inputs must match it),
# and the padding id.
model.eval()
model_device = next(model.parameters()).device
pad_id = vocab['<pad>']

for sentence in test_sentences:
    tokens = tokenizer(sentence)
    ids = torch.tensor([vocab(tokens)], dtype=torch.long, device=model_device)
    padm = ids == pad_id

    with torch.no_grad():
        logits, attn = model(ids, src_key_padding_mask=padm, return_attn=True)
        # NOTE(review): attn is assumed to be [1, seq_len, seq_len] (single
        # head); confirm against RequestClassifier.
        attn_np = np.asarray(attn[0].detach().cpu().tolist())

    # Total attention each token receives as a key, summed over all queries.
    col_importance = attn_np[:, :len(tokens)].sum(axis=0).tolist()

    utility.display_tokens_with_alpha(tokens, col_importance)
    p_request = F.softmax(logits, dim=-1)[0, 1].item()
    print(f"Prediction: {label_map[logits.argmax(dim=-1).item()]} ; P(request)={p_request:.4f}")
    print("")
